Need help understanding Conv2d and fold, unfold

I’m working on a CNN that processes each patch directly. After reading the documentation on fold and unfold, my understanding is that I can first apply a convolution to an arbitrary [b, c, h, w] input named A, with some parameters for stride, dilation and padding. Let’s call the output B, with shape [b, c, h1, w1].

My understanding of how fold and unfold work is as follows: if I unfold the input A, I get something of shape [b, H, L]. Then I can apply some transformation to the H dimension, and then use fold with B’s shape for the output_size parameter, keeping the same parameters such as stride and dilation.

However, it doesn’t look like it works this way. If my understanding is correct, the following example should execute without a problem. But it doesn’t run. Can someone point out what I missed?

import torch
from torch import nn

A = torch.randn(4, 32, 224, 224)
out_shape = nn.Conv2d(32, 32, kernel_size=3, dilation=2, stride=2, padding=0)(A).shape[-2:]
windows = torch.nn.functional.unfold(A, kernel_size=3, dilation=2, stride=2, padding=0)
# The next line is the call that fails
A_ = torch.nn.functional.fold(windows, output_size=out_shape, kernel_size=3, dilation=2, stride=2, padding=0)

This complains that fold with output_size=110 requires L=2809 (which is 53**2), while the windows variable has L=12100 (which is 110**2). I then tried fold with output_size=224, which no longer complains but gives me back a shape of [4, 32, 224, 224].
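For anyone wondering where those two numbers come from, here is a small sketch that reproduces them. It assumes the per-dimension block-count formula given in the Fold/Unfold docs; num_blocks is a hypothetical helper name, not a PyTorch function, and the input is shrunk to one sample and two channels to keep it cheap.

```python
import torch
import torch.nn.functional as F

def num_blocks(size, kernel_size, dilation=1, stride=1, padding=0):
    # Hypothetical helper: sliding-window positions along one spatial
    # dimension, following the formula in the Unfold documentation.
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# L that unfold produces for a 224 x 224 input
l_from_224 = num_blocks(224, 3, dilation=2, stride=2) ** 2
# L that fold expects when asked for a 110 x 110 output
l_for_110 = num_blocks(110, 3, dilation=2, stride=2) ** 2

A = torch.randn(1, 2, 224, 224)
windows = F.unfold(A, kernel_size=3, dilation=2, stride=2, padding=0)
print(windows.shape[-1], l_from_224, l_for_110)
```

So fold reads output_size as the size of the tensor the patches were taken from, and checks L against that size.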

Take a look at this documentation for the non-functional version:

https://pytorch.org/docs/stable/generated/torch.nn.Fold.html#torch.nn.Fold

In particular, note this relation:

fold(unfold(input)) == divisor * input

I believe Fold would restore the original input shape rather than yield the output shape of the corresponding convolution.
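As a rough check of that relation, the sketch below folds the unfolded patches back to the input size and compares the result with divisor * input. The divisor is computed by running a tensor of ones through the same unfold/fold pair, as the docs suggest. The small 16 x 16 input and the variable names are only for illustration.

```python
import torch
import torch.nn.functional as F

params = dict(kernel_size=3, dilation=2, stride=2, padding=0)
x = torch.randn(1, 2, 16, 16)

windows = F.unfold(x, **params)
# fold takes the size of the ORIGINAL input, not the conv output size
restored = F.fold(windows, output_size=(16, 16), **params)

# Number of patches that cover each pixel
divisor = F.fold(F.unfold(torch.ones_like(x), **params), output_size=(16, 16), **params)

print(restored.shape)
print(torch.allclose(restored, divisor * x))
```

With these particular stride and dilation values some pixels are covered by no patch, so the divisor is zero there and those input values cannot be recovered from the patches.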

I see. Now that you point it out, the doc does say that fold restores the original size. Do you know of a way to preserve the behavior of parameters such as dilation and stride, so that the effect of unfold and fold is almost identical to that of a Conv2d module?

I would think about your problem from the perspective that some sort of reduction needs to happen in order to transform the Unfold output to the equivalent of a conv. This example may not be the exact operation you are trying to implement but it illustrates some basics:

import torch
from torch import nn

A = torch.ones(4, 32, 224, 224)
# Single output channel, all-ones weights, no bias: the conv just sums each patch
conv = nn.Conv2d(32, 1, kernel_size=3, dilation=2, stride=2, padding=0, bias=False)
conv.weight = torch.nn.Parameter(torch.ones_like(conv.weight))
out = conv(A)
print(out.shape)
windows = torch.nn.functional.unfold(A, kernel_size=3, dilation=2, stride=2, padding=0)
print(windows.shape)
# Split L back into the conv output's height and width
windows = windows.reshape(4, 1, 32*9, out.shape[-2], out.shape[-1])
# Reduce over the 32*9 patch values to get one number per output position
windows_reduced = torch.sum(windows, dim=2)
print(windows_reduced.shape)
print(torch.allclose(out, windows_reduced))

Output:

torch.Size([4, 1, 110, 110])
torch.Size([4, 288, 12100])
torch.Size([4, 1, 110, 110])
True

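Note that this sum is equivalent to a convolution with a single output channel and all-ones weights. Unfold returns the patches with shape [b, c * kernel_h * kernel_w, L], and summing over the middle dimension collapses each patch to one value. For other weights or several output channels, the sum has to be replaced by a weighted reduction, such as a matrix multiplication between the flattened conv weights and the unfolded patches.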
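If it helps, here is a minimal sketch of that more general case, along the lines of the example in the Unfold docs: the patches are multiplied by the flattened Conv2d weights and L is reshaped to the conv output size. The shapes (4 input channels, 8 output channels, 32 x 32 input) and the tolerance are arbitrary choices for illustration, and bias is left out.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
params = dict(kernel_size=3, dilation=2, stride=2, padding=0)
x = torch.randn(2, 4, 32, 32)
conv = nn.Conv2d(4, 8, bias=False, **params)

with torch.no_grad():
    expected = conv(x)                      # [2, 8, h1, w1]
    windows = F.unfold(x, **params)         # [2, 4*3*3, L]
    weight = conv.weight.reshape(8, -1)     # [8, 4*3*3]
    out = weight @ windows                  # [2, 8, L], broadcast over batch
    out = out.reshape(2, 8, *expected.shape[-2:])

print(out.shape)
print(torch.allclose(out, expected, atol=1e-4))
```

Any per-patch transformation could take the place of the matrix multiplication, as long as the result is reshaped with the conv output's height and width rather than passed to fold.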

Thank you for the detailed reply. I do see this solution working and I’ll go with something along this line.