How is conv2d implemented in PyTorch?

Is there any Python code that implements the forward pass of the following module?

```python
torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
```
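
For reference, the spatial size of the output for these arguments follows the shape formula in the `torch.nn.Conv2d` documentation. The sketch below computes it (the helper names `conv2d_out_size` and `_pair` are hypothetical, not PyTorch API); this is where the `(7, 8)` passed to `fold` later in the thread comes from.

```python
import torch

def _pair(x):
    # Accept either an int or an (h, w) tuple, as Conv2d does.
    return (x, x) if isinstance(x, int) else tuple(x)

def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """(H_out, W_out) of Conv2d for an input of spatial size (H, W)."""
    k, s, p, d = map(_pair, (kernel_size, stride, padding, dilation))
    return tuple(
        (size[i] + 2 * p[i] - d[i] * (k[i] - 1) - 1) // s[i] + 1 for i in range(2)
    )

# A 10x12 input with a 4x5 kernel (all other arguments at their defaults).
print(conv2d_out_size((10, 12), (4, 5)))  # (7, 8)

# Cross-check against the real module for a non-trivial configuration.
conv = torch.nn.Conv2d(3, 2, kernel_size=3, stride=2, padding=1, dilation=2)
out = conv(torch.randn(1, 3, 11, 13))
print(tuple(out.shape[-2:]), conv2d_out_size((11, 13), 3, 2, 1, 2))
```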

Hi,

No, there is none.
There are dedicated C++/CUDA kernels for this. On the GPU, it mostly uses cuDNN's own implementation, which itself chooses among many different algorithms.
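
Purely to illustrate what those kernels compute, here is the direct definition written as explicit Python loops. This is a sketch, not PyTorch's implementation: it assumes the defaults `stride=1, padding=0, dilation=1, groups=1`, the name `naive_conv2d` is hypothetical, and it is far too slow for real use. Like PyTorch, it computes a cross-correlation (the kernel is not flipped).

```python
import torch
import torch.nn.functional as F

def naive_conv2d(inp, weight, bias=None):
    """Direct conv2d with stride=1, padding=0, dilation=1, groups=1.

    inp: (N, C_in, H, W), weight: (C_out, C_in, kH, kW), bias: (C_out,) or None.
    """
    n, c_in, h, w = inp.shape
    c_out, _, kh, kw = weight.shape
    h_out, w_out = h - kh + 1, w - kw + 1
    out = inp.new_zeros((n, c_out, h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            # Window under the kernel for every channel: (N, C_in, kH, kW).
            patch = inp[:, :, i:i + kh, j:j + kw]
            # Multiply with each filter and sum over (C_in, kH, kW) -> (N, C_out).
            out[:, :, i, j] = torch.einsum("nchw,ochw->no", patch, weight)
    if bias is not None:
        out += bias.view(1, -1, 1, 1)
    return out

inp = torch.randn(1, 3, 10, 12)
w = torch.randn(2, 3, 4, 5)
b = torch.randn(2)
print((naive_conv2d(inp, w, b) - F.conv2d(inp, w, b)).abs().max())
```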

The classic implementation that we use on CPU is based on matrix multiplication (im2col) and would look like this in Python (note that this uses more memory and is slower than the direct conv2d implementation):

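```python
import torch

inp = torch.randn(1, 3, 10, 12)  # (N, C_in, H, W)
w = torch.randn(2, 3, 4, 5)      # (C_out, C_in, kH, kW)

# Handmade conv
# unfold: (1, 3*4*5, L) = (1, 60, 56), one column per window position (7*8 = 56)
inp_unf = torch.nn.functional.unfold(inp, (4, 5))
# (1, 56, 60) @ (60, 2) -> (1, 56, 2), transposed back to (1, 2, 56)
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
# Lay the 56 positions out as the 7x8 output map
out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
# or equivalently (and avoiding a copy),
# out = out_unf.view(1, 2, 7, 8)

# Check that the result is correct
print((torch.nn.functional.conv2d(inp, w) - out).abs().max())
```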
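
The same idea extends to the other arguments of `Conv2d`, since `unfold` accepts `stride`, `padding` and `dilation` directly. A minimal sketch, assuming `groups=1` and integer (not tuple) arguments; `unfold_conv2d` is a hypothetical name:

```python
import torch
import torch.nn.functional as F

def unfold_conv2d(inp, weight, bias=None, stride=1, padding=0, dilation=1):
    """conv2d via unfold + matmul (im2col). Assumes groups=1 and int arguments."""
    n, _, h, w = inp.shape
    c_out, _, kh, kw = weight.shape
    # (N, C_in*kH*kW, L): one column per output location.
    cols = F.unfold(inp, (kh, kw), dilation=dilation, padding=padding, stride=stride)
    # (C_out, C_in*kH*kW) @ (N, C_in*kH*kW, L) -> (N, C_out, L), matrix broadcast over N.
    out = weight.reshape(c_out, -1) @ cols
    if bias is not None:
        out = out + bias.view(1, -1, 1)
    # Output spatial size from the Conv2d shape formula.
    h_out = (h + 2 * padding - dilation * (kh - 1) - 1) // stride + 1
    w_out = (w + 2 * padding - dilation * (kw - 1) - 1) // stride + 1
    return out.view(n, c_out, h_out, w_out)

x = torch.randn(2, 3, 11, 13)
k = torch.randn(4, 3, 3, 3)
b = torch.randn(4)
mine = unfold_conv2d(x, k, b, stride=2, padding=1, dilation=2)
ref = F.conv2d(x, k, b, stride=2, padding=1, dilation=2)
print((mine - ref).abs().max())
```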

Thanks for the answer. Do you know if this handmade conv will give exactly the same results as Conv2d? When we train it for some task, will they be exactly the same? Or are there other hidden optimizations, or different ways of computing the gradients, in the CUDA implementation?

Hi,

It will give the same result up to floating-point precision.
So no, it might not be bitwise identical.
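
On the training question: autograd differentiates straight through `unfold`, `matmul` and `view`, so the handmade version should also get gradients equal to `conv2d`'s up to rounding. A quick check, written with a single `matmul` instead of the two transposes and comparing with `torch.allclose` rather than exact equality (the tolerance `1e-4` is an arbitrary choice):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
inp = torch.randn(1, 3, 10, 12, requires_grad=True)
w = torch.randn(2, 3, 4, 5, requires_grad=True)

# Handmade conv: (2, 60) @ (1, 60, 56) -> (1, 2, 56) -> (1, 2, 7, 8)
inp_unf = F.unfold(inp, (4, 5))
out = (w.view(2, -1) @ inp_unf).view(1, 2, 7, 8)
ref = F.conv2d(inp, w)

# Compare with a tolerance, not bit-for-bit.
print(torch.allclose(out, ref, atol=1e-4))

# Backpropagate the same upstream gradient through both and compare.
g = torch.randn_like(ref)
gi, gw = torch.autograd.grad(out, (inp, w), g)
gi_ref, gw_ref = torch.autograd.grad(ref, (inp, w), g)
print(torch.allclose(gi, gi_ref, atol=1e-4), torch.allclose(gw, gw_ref, atol=1e-4))
```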

I computed the gradient with respect to the weight using this method and it matched PyTorch's, but the gradient with respect to the input doesn't match. Could you please help me implement it using fold and unfold?
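
One way to write both gradients by hand with `unfold` and `fold`, as a sketch assuming `groups=1` and no bias (the bias gradient would simply be `grad_out.sum((0, 2, 3))`); the function name `conv2d_backward` is hypothetical. The key point for the input gradient is that `fold` is the adjoint of `unfold`: it sums values from overlapping windows back into place. So the output gradient is mapped back to columns with the transposed weight matrix and then folded, rather than trying to invert `unfold`.

```python
import torch
import torch.nn.functional as F

def conv2d_backward(inp, weight, grad_out, stride=1, padding=0, dilation=1):
    """Gradients of conv2d (groups=1, no bias) w.r.t. input and weight."""
    n, _, h, w = inp.shape
    c_out, _, kh, kw = weight.shape
    # Same columns as in the forward pass: (N, C_in*kH*kW, L).
    cols = F.unfold(inp, (kh, kw), dilation=dilation, padding=padding, stride=stride)
    # Upstream gradient with flattened spatial dims: (N, C_out, L).
    g = grad_out.reshape(n, c_out, -1)
    # Weight: sum over batch and locations of g @ cols^T -> (C_out, C_in*kH*kW).
    grad_w = (g @ cols.transpose(1, 2)).sum(0).view_as(weight)
    # Input: W^T @ g gives per-column gradients (N, C_in*kH*kW, L); fold then
    # sums the overlapping windows back onto the (H, W) grid.
    grad_cols = weight.reshape(c_out, -1).t() @ g
    grad_inp = F.fold(grad_cols, (h, w), (kh, kw), dilation=dilation,
                      padding=padding, stride=stride)
    return grad_inp, grad_w

# Compare against autograd on a strided, padded, dilated example.
x = torch.randn(2, 3, 11, 13, requires_grad=True)
k = torch.randn(4, 3, 3, 3, requires_grad=True)
out = F.conv2d(x, k, stride=2, padding=1, dilation=2)
g = torch.randn_like(out)
gi_ref, gw_ref = torch.autograd.grad(out, (x, k), g)
gi, gw = conv2d_backward(x.detach(), k.detach(), g, stride=2, padding=1, dilation=2)
print((gi - gi_ref).abs().max(), (gw - gw_ref).abs().max())
```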