ConvTranspose2d using unfold

Hello,
Suppose I have a matrix img and a kernel kernel.
Is it possible to do a transposed convolution as a matrix multiplication?
I know that when you unroll the kernel you have to transpose it, but I can't figure out how to unroll the input.

import torch
# as an input I'm using a tensor with the size of an MNIST digit
img = torch.randn(1, 1, 28, 28)
# kernel with 1 input channel and 1 output channel
kernel = torch.randn(1, 1, 3, 3)
# unfold with kernel size 3, stride 2, padding 1
unfold = torch.nn.Unfold(kernel_size=(3, 3), stride=2, padding=1)
# unfolded data: shape (1, 1*3*3, 14*14) = (1, 9, 196)
img_unfolded = unfold(img)

"""
so how to do the convtranspose using img_unfolded and kernel?
"""

Thanks for looking at it, and have a nice day 🙂


If your goal is to do a transposed convolution, why don't you use nn.functional.conv_transpose2d? torch.nn.functional — PyTorch 1.8.1 documentation
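For reference, a minimal sketch of calling it directly (stride and padding mirror your unfold call; output_padding=1 is just my choice here so the output size comes out even):

import torch
import torch.nn.functional as F

img = torch.randn(1, 1, 28, 28)
# conv_transpose2d expects the weight as (in_channels, out_channels, kH, kW)
kernel = torch.randn(1, 1, 3, 3)
out = F.conv_transpose2d(img, kernel, stride=2, padding=1, output_padding=1)
print(out.shape)  # torch.Size([1, 1, 56, 56]): 28x28 upsampled to 56x56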

Because I want to do it from scratch, to really understand what's happening behind those function calls.

Hello,

I have also needed to do this exact thing myself (for various reasons) for a while now and struggled to figure it out. Honestly, it's way harder to find answers for this than it should be. I read a lot of material on convolution, PyTorch unfold, ConvTranspose2d, and CNN gradients. Finally, I just got it.

Fortunately, the answer is actually very simple; it's just that everyone seems to have a different view of this operation (upsampling, CNN gradient, deconvolution, etc.) and none of them quite explains everything, so this answer isn't very "googleable". Note that this is the inefficient way of doing things: to do this with unfold, we have to add a lot of padding on all sides. There are more efficient implementations, but this is the best vectorized one I can come up with.

import torch
import torch.nn.functional as F

img = torch.randn(1, 50, 28, 28)
kernel = torch.randn(30, 50, 3, 3)
# conv_transpose2d expects the weight as (in_channels, out_channels, kH, kW)
true_convt2d = F.conv_transpose2d(img, kernel.transpose(0, 1))

# "full" padding of kernel_size - 1 on each side, written out to explicitly
# show the padding calculation behind conv_transpose2d
pad0 = 3 - 1
pad1 = 3 - 1
inp_unf = torch.nn.functional.unfold(img, (3, 3), padding=(pad0, pad1))
# rotate the kernel 180 degrees in the spatial dimensions
w = torch.rot90(kernel, 2, [2, 3])
# the matmul is done the same way as for forward convolution
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
out = out_unf.view(true_convt2d.shape)
print((true_convt2d - out).abs().max())
print(true_convt2d.abs().max())

Running this code gives a small error value around 1e-5, but the magnitude of the output is around 90, so it's close enough. I think the difference is just floating-point error from the backend computing things in a different order.
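For comparison, here is ordinary forward convolution done as unfold + matmul (the im2col trick that the matmul line above mirrors): a sketch with the same shapes, but with no kernel rotation and no extra padding.

import torch
import torch.nn.functional as F

img = torch.randn(1, 50, 28, 28)
kernel = torch.randn(30, 50, 3, 3)
inp_unf = F.unfold(img, (3, 3))  # (1, 50*3*3, 26*26)
out_unf = inp_unf.transpose(1, 2).matmul(kernel.view(kernel.size(0), -1).t()).transpose(1, 2)
out = out_unf.view(1, 30, 26, 26)
print((F.conv2d(img, kernel) - out).abs().max())  # again around 1e-5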

I hope this answer becomes "googleable" for others looking for this information. I am a new user, so I can apparently only put 2 links in a post; if you want more, please DM me. There is a formula for calculating the output shape (and from it the padding) on Data Science Stack Exchange (though it's not too hard to figure out) if you search "how-to-calculate-the-output-shape-of-conv2d-transpose".
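For convenience, here is that output-shape formula as a tiny helper (it matches the size formula in the ConvTranspose2d docs; the function name is just mine for illustration):

def convt2d_out_size(size_in, kernel_size, stride=1, padding=0,
                     output_padding=0, dilation=1):
    # transposed-convolution output size along one spatial dimension
    return ((size_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

print(convt2d_out_size(28, 3))  # 30, the spatial size of true_convt2d above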

Disclaimer: I am pretty sure this is correct, but it could still be wrong. Also, this doesn't take padding, strides, dilation, or groups into account.

Sources:
Thorough descriptions of convolution
Visualization that explains rotation


More helpful sources, provided by @santacml:
https://danieltakeshi.github.io/2019/03/09/conv-matmul/
https://towardsdatascience.com/backpropagation-in-a-convolutional-layer-24c8d64d8509

You are my hero!
I have spent the whole day just on this.