How to collapse a series of 2D convolutions


I would like to collapse a series of 2D convolutions into a single conv operation.
Basically, I’d like to use the PyTorch convolution as if it were associative.

For example, if I have something like
y = F.conv2d(F.conv2d(a1,k1),k0)
I’d like to write it this way: y = F.conv2d(a1,K) (where K depends on k1 and k0).

If my calculations are not wrong, this should work:

import torch
import torch.nn.functional as F

a1 = torch.randn(1,100,256,256)
k1 = torch.randn(25,100,5,5)
k0 = torch.rand(1,25,5,5)

a0 = F.conv2d(a1,k1)
y = F.conv2d(a0,k0)

K = F.conv2d(k1.transpose(0,1),k0,padding=4).transpose(0,1)
y2 = F.conv2d(a1,K)

But y and y2 do not match.
Does anyone know what I am missing? Thanks

The convolution operator in PyTorch is a cross-correlation and not a convolution in the signal processing sense.
From the docs:

[…] where ⋆ is the valid 2D cross-correlation operator, N is a batch size, C denotes a number of channels, H is a height of input planes in pixels, and W is width in pixels.
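Writing this operator out with the names from the question (batch dimension dropped, stride 1, no padding, indices starting at 0) shows where the mismatch comes from. This is a sketch of the index bookkeeping, not something taken from the PyTorch docs:

$$
(a \star k)[o, i, j] = \sum_{c} \sum_{u,v} a[c,\, i+u,\, j+v]\; k[o, c, u, v]
$$

Chaining the two calls, with a0 = a1 ⋆ k1 and y = a0 ⋆ k0, and substituting s = p + u, t = q + v:

$$
\begin{aligned}
y[0,i,j] &= \sum_{c} \sum_{p,q} a_0[c,\, i+p,\, j+q]\; k_0[0,c,p,q] \\
&= \sum_{c} \sum_{p,q} \sum_{d} \sum_{u,v} a_1[d,\, i+p+u,\, j+q+v]\; k_1[c,d,u,v]\; k_0[0,c,p,q] \\
&= \sum_{d} \sum_{s,t} a_1[d,\, i+s,\, j+t]\; K[0,d,s,t],
\qquad
K[0,d,s,t] = \sum_{c} \sum_{p,q} k_0[0,c,p,q]\; k_1[c,d,\, s-p,\, t-q]
\end{aligned}
$$

(terms where s − p or t − q fall outside 0..4 are zero). So the combined 9×9 kernel is a true, full convolution of k1 with k0 over the spatial axes, summed over the 25 intermediate channels. F.conv2d(k1.transpose(0,1), k0, padding=4) instead computes a sum of k1[c,d,s+p−4,t+q−4] k0[0,c,p,q], i.e. a cross-correlation; it matches K only after k0 is flipped spatially (p → 4 − p, q → 4 − q).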

A convolution is associative (and commutative), while a cross-correlation is neither.
Thus, you would need to flip the kernels in the spatial dimensions (turning each cross-correlation into a true convolution) to get approximately the same result:

import torch
import torch.nn.functional as F

a1 = torch.randn(1,100,256,256)
k1 = torch.randn(25,100,5,5)
k0 = torch.randn(1,25,5,5)

a0 = F.conv2d(a1,torch.flip(k1, [2, 3]))
y = F.conv2d(a0,torch.flip(k0, [2, 3]))

K = F.conv2d(k1.transpose(0, 1),torch.flip(k0, [2, 3]),padding=4).transpose(0, 1)
y2 = F.conv2d(a1,torch.flip(K, [2, 3]))

print((y - y2).abs().max())
> tensor(0.0052)

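The max. absolute error seems a bit high, but lowering the number of channels reduces it to approx. 1e-5, so I assume it’s due to the limited floating-point precision (or I’m not seeing the bug in the code). Note also that the entries of y are large here (standard deviation of roughly sqrt(100·25·25·25) = 1250), so 0.0052 corresponds to a relative error of only a few 1e-6.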
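If the floating-point explanation is right, the discrepancy should drop to round-off level in float64. Below is a minimal sketch of that check. The float64 dtype, the fixed seed, the smaller 32×32 input, and the helper rel_err are illustrative assumptions, not part of the thread. It also compares the question’s original chain of cross-correlations (no flips) against the same combined kernel K, built with only k0 flipped and then used unflipped.

```python
import torch
import torch.nn.functional as F


def rel_err(ref, other):
    # Hypothetical helper: max. absolute difference, scaled by the largest |ref| entry.
    return ((ref - other).abs().max() / ref.abs().max()).item()


torch.manual_seed(0)

# Same channel and kernel sizes as in the thread; float64 and a 32x32 input
# are assumptions chosen to make the check precise and fast.
a1 = torch.randn(1, 100, 32, 32, dtype=torch.float64)
k1 = torch.randn(25, 100, 5, 5, dtype=torch.float64)
k0 = torch.randn(1, 25, 5, 5, dtype=torch.float64)

# Combined kernel: full convolution of k1 with the spatially flipped k0,
# summed over the 25 intermediate channels -> shape (1, 100, 9, 9).
K = F.conv2d(k1.transpose(0, 1), torch.flip(k0, [2, 3]), padding=4).transpose(0, 1)

# 1) The answer's version: chain of true convolutions (flipped kernels),
#    collapsed with the flipped combined kernel.
y_conv = F.conv2d(F.conv2d(a1, torch.flip(k1, [2, 3])), torch.flip(k0, [2, 3]))
err_answer = rel_err(y_conv, F.conv2d(a1, torch.flip(K, [2, 3])))

# 2) The question's chain of cross-correlations (no flips), collapsed with K
#    used as-is.
y = F.conv2d(F.conv2d(a1, k1), k0)
y2 = F.conv2d(a1, K)
err_fixed = rel_err(y, y2)

# 3) The question's original K (k0 not flipped), for comparison.
K_unflipped = F.conv2d(k1.transpose(0, 1), k0, padding=4).transpose(0, 1)
err_unflipped = rel_err(y, F.conv2d(a1, K_unflipped))

print(f"flipped chain vs flip(K):         {err_answer:.2e}")
print(f"cross-correlation chain vs K:     {err_fixed:.2e}")
print(f"cross-correlation chain vs old K: {err_unflipped:.2e}")
```

With float64, the first two relative errors should be at round-off level while the third stays large, which would point to float32 accumulation rather than a bug in the flipped version.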