Transpose convolution math not working out

I was reading A Guide to Convolution Arithmetic for Deep Learning to understand transposed convolution.

From section 4.1

Using this representation, the backward pass is easily obtained by transposing C; in other words, the error is backpropagated by multiplying the loss with C.T. This operation takes a 4-dimensional vector as input and produces a 16-dimensional vector as output, and its connectivity pattern is compatible with C by construction.

When I try this out in PyTorch, the backpropagated error is certainly not what I get from multiplying with C.T.

import torch
import torch.nn.functional as F

x = torch.arange(1, 17, dtype=torch.float).reshape(4, 4)  # 4x4 input
w = torch.rand(3, 3)  # 3x3 convolution kernel

Convolve w and x

# Convert x into an "image" tensor with a single channel and part of a mini-batch of size 1
x1 = x.view(1, 1, 4, 4)
x1.requires_grad = True

# Convert w into a conv filter with a single input channel and a single output channel
w1 = w.view(1, 1, 3, 3)
w1.requires_grad = True

y1 = F.conv2d(x1, w1)
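
Just to keep the shapes straight: the 4x4 input convolved with the 3x3 kernel (no padding, stride 1) gives a 2x2 output, which corresponds to the 4-dimensional vector mentioned in the quoted passage.

# Quick check on the output shape
print(y1.shape)  # torch.Size([1, 1, 2, 2])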

Backpropagate

y1.backward(torch.ones_like(y1))
x1.grad

Now create the C matrix as mentioned in the paper.

# Each row of C holds the kernel weights at the flattened-input positions
# that contribute to one output element (outputs in row-major order).
C = torch.zeros(4, 16, dtype=torch.float)

# Output (0, 0): input rows 0-2, columns 0-2
C[0, :3] = w[0]
C[0, 4:7] = w[1]
C[0, 8:11] = w[2]

# Output (0, 1): input rows 0-2, columns 1-3
C[1, 1:4] = w[0]
C[1, 5:8] = w[1]
C[1, 9:12] = w[2]

# Output (1, 0): input rows 1-3, columns 0-2
C[2, 4:7] = w[0]
C[2, 8:11] = w[1]
C[2, 12:15] = w[2]

# Output (1, 1): input rows 1-3, columns 1-3
C[3, 5:8] = w[0]
C[3, 9:12] = w[1]
C[3, 13:] = w[2]
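
As a quick sanity check (not in the original construction, but it follows directly from it), multiplying C by the flattened input reproduces the forward convolution:

# Forward pass as a matrix product: C @ vec(x) matches F.conv2d(x1, w1)
print(torch.allclose(torch.mm(C, x.view(16, 1)).view(1, 1, 2, 2), y1))  # True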

Multiplying the unrolled y1 by C.T does not equal x1.grad.

torch.mm(C.transpose(0, 1), y1.view(-1, 1)).view(4, 4)

What am I doing wrong?

This needs to be

torch.mm(C.transpose(0, 1), torch.ones_like(y1).view(-1, 1)).view(4, 4)
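
A quick check (reusing the tensors defined above) confirms that this reproduces what autograd computed:

# Sanity check: the manual C.T product equals autograd's x1.grad
grad_manual = torch.mm(C.transpose(0, 1), torch.ones_like(y1).view(-1, 1)).view(4, 4)
print(torch.allclose(grad_manual, x1.grad.view(4, 4)))  # True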

When you have a product that you want to backpropagate through, you replace the factor with respect to which you want to differentiate by the (appropriately expanded and summed-up) output gradient, not by the output itself.
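
Concretely, for the linear map y = C @ vec(x) this says vec(x).grad = C.T @ y.grad for whatever upstream gradient arrives, not just the all-ones one used above. A minimal sketch, reusing C with an arbitrary random gradient:

# x.grad equals C.T @ grad_y for any upstream gradient grad_y
x_vec = torch.randn(16, 1, requires_grad=True)
y_vec = torch.mm(C, x_vec)
grad_y = torch.randn_like(y_vec)  # arbitrary upstream gradient
y_vec.backward(grad_y)
print(torch.allclose(x_vec.grad, torch.mm(C.transpose(0, 1), grad_y)))  # True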

My favourite way of looking at derivatives of (multi)linear functions is in terms of Einstein summation notation. Let's take torch.nn.functional.bilinear as an example: it can be written as an (inefficient) einsum, out = torch.einsum('bi,kij,bj->bk', left, weight, right). Now, since torch.expand and torch.sum are dual to each other when taking derivatives, and by the product rule of differentiation, you have
weight.grad = torch.einsum('bi,bk,bj->kij', left, out.grad, right) etc. So you swap out.grad in for one of the arguments and exchange that factor with the right-hand side of the equation, and you get the right gradient formula.
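
One way to convince yourself is a quick sketch with made-up shapes, comparing autograd's weight.grad against the swapped einsum:

# Verify the swapped-einsum gradient formula for the bilinear form
left = torch.randn(5, 3)
right = torch.randn(5, 4)
weight = torch.randn(2, 3, 4, requires_grad=True)

out = torch.einsum('bi,kij,bj->bk', left, weight, right)
grad_out = torch.randn_like(out)  # arbitrary upstream gradient
out.backward(grad_out)

manual = torch.einsum('bi,bk,bj->kij', left, grad_out, right)
print(torch.allclose(weight.grad, manual))  # True (up to floating-point tolerance)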

But now I got carried away, …

Best regards

Thomas