Computing gradients with respect to weights in custom Conv2d


I'm currently working on implementing Conv2d from scratch without using autograd. I already got the forward pass working using unfold, but something goes wrong in the backward pass. As I understood it, once you unfold your input and reshape your kernel, the steps for the forward and backward pass are basically the same as in a normal linear network. But when I implement it like that, I get weird shapes. Here's the code I'm currently using.

import torch
# as input I'm using a tensor with the size of an MNIST digit
img = torch.randn(1, 1, 28, 28)
# kernel with 1 input channel and 1 output channel
kernel = torch.randn(1, 1, 3, 3)
# unfold with kernel size 3 and padding 1
unfold = torch.nn.Unfold(kernel_size=(3, 3), padding=1)
# unfolded data, shape (1, 9, 784)
img_unfolded = unfold(img)
# forward pass with kernel and unfolded data
# reshape kernel to (out_channels, -1)
output = kernel.view(1, -1).matmul(img_unfolded)
# output has shape (1, 1, 784); you'd then reshape it to (1, 1, 28, 28)
# to have the right dimensions for the next layer
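As a quick sanity check (a minimal sketch, not part of the original post), the unfold + matmul forward pass can be compared against torch.nn.functional.conv2d, which also computes a cross-correlation. The seed and the name reference are only for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # illustrative seed, only for reproducibility
img = torch.randn(1, 1, 28, 28)
kernel = torch.randn(1, 1, 3, 3)

unfold = torch.nn.Unfold(kernel_size=(3, 3), padding=1)
img_unfolded = unfold(img)                        # (1, 9, 784)
output = kernel.view(1, -1).matmul(img_unfolded)  # (1, 9) @ (1, 9, 784) -> (1, 1, 784)

# conv2d in PyTorch is a cross-correlation, which is what unfold + matmul computes
reference = F.conv2d(img, kernel, padding=1)      # (1, 1, 28, 28)
print(torch.allclose(output.view(1, 1, 28, 28), reference, atol=1e-5))
```

The kernel is flattened in (in_channels, kH, kW) order, which matches the row order Unfold uses for each column, so the same reshape also works with more input channels.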

Now assume that the gradients of the output layer have already been computed as randn(1, 1, 784):

# gradients already calculated, with the shape of the output layer
grads_output = torch.randn(1, 1, 784)
# so how do I get the gradients with respect to the weights?
# I tried img_unfolded transposed
# times output, but the shape is just weird
grads_weights = img_unfolded.T.matmul(output)
# shape is (784, 9, 784), which is clearly wrong because the
# right shape would be (1, 9)

But when I try to compute the gradients for the next layer (i.e. with respect to the unfolded input), the result is what I expect:

# transposed reshaped kernel times the gradients of the output
grads_next_layer = kernel.view(1, -1).T.matmul(grads_output)
# as said, it has the right shape of (1, 9, 784)
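grads_next_layer is still in the unfolded "column" layout. Below is a minimal sketch of one way to turn it into a gradient with respect to img, assuming stride 1 and padding 1 as above. torch.nn.Fold is the adjoint of Unfold: it scatters each column entry back to its pixel and sums the contributions from overlapping 3x3 windows. The names grads_cols and grads_img are hypothetical, and autograd is used only to check the result.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # illustrative seed
img = torch.randn(1, 1, 28, 28, requires_grad=True)
kernel = torch.randn(1, 1, 3, 3)
grads_output = torch.randn(1, 1, 784)

# manual backward: (9, 1) @ (1, 1, 784) -> (1, 9, 784), one gradient per unfolded entry
grads_cols = kernel.view(1, -1).T.matmul(grads_output)

# Fold sums every column entry back into the pixel it was copied from;
# entries that landed in the zero padding are dropped
fold = torch.nn.Fold(output_size=(28, 28), kernel_size=(3, 3), padding=1)
grads_img = fold(grads_cols)  # (1, 1, 28, 28)

# reference gradient from autograd, used only for checking
out = F.conv2d(img, kernel, padding=1)
out.backward(grads_output.view(1, 1, 28, 28))
print(torch.allclose(grads_img, img.grad, atol=1e-5))
```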

So my question is: how do I correctly calculate the gradients with respect to the weights?

Thanks for taking the time to read this, and have a nice day!

Out[13]: torch.Size([1, 9, 784])
Out[14]: torch.Size([784, 9, 1])

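So .T does the wrong permutation: on a 3-D tensor it reverses all dimensions, turning (1, 9, 784) into (784, 9, 1) instead of swapping only the last two.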

The gradient of matmul(kernel, img_unfolded) with respect to kernel is grads_output @ img_unfolded^T. The 3-D (batched) version of this should use shapes (1, 1, 784) @ (1, 784, 9), which gives (1, 1, 9).

I'd suggest using .transpose(-2, -1) instead.
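Putting that suggestion together, here is a minimal sketch of the weight gradient, assuming the same shapes as in the question (the seed is illustrative). Note that it multiplies by the upstream gradient grads_output rather than the forward output, and that it sums over the batch dimension because the same kernel is applied to every sample. Autograd is used only to check the result.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # illustrative seed
img = torch.randn(1, 1, 28, 28)
kernel = torch.randn(1, 1, 3, 3, requires_grad=True)
grads_output = torch.randn(1, 1, 784)

unfold = torch.nn.Unfold(kernel_size=(3, 3), padding=1)
img_unfolded = unfold(img)  # (1, 9, 784)

# dL/dW = dL/dY @ X^T, batched: (1, 1, 784) @ (1, 784, 9) -> (1, 1, 9)
# transpose(-2, -1) swaps only the last two dims, unlike .T on a 3-D tensor
grads_weights = grads_output.matmul(img_unfolded.transpose(-2, -1))

# the kernel is shared across the batch, so sum over dim 0 -> (1, 9),
# then reshape to the kernel's own shape (1, 1, 3, 3)
grads_weights = grads_weights.sum(dim=0).view_as(kernel)

# reference gradient from autograd, used only for checking
out = F.conv2d(img, kernel, padding=1)
out.backward(grads_output.view(1, 1, 28, 28))
print(torch.allclose(grads_weights, kernel.grad, atol=1e-3))
```

With a batch of N images the matmul gives (N, out_channels, in_channels*3*3), and the sum over dim 0 still leaves exactly one gradient per kernel weight.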


Thanks so much for the solution! Have a nice day