Fold(unfold(x)) != x

In[1]: x = torch.rand(32, 3, 16, 16)
In[2]: torch.equal(F.fold(F.unfold(x, 3), x.shape[2], 3), x)
Out[2]: False

Shouldn’t this return True? If not, how do I invert this operation?
I tried running this on MNIST images, and I can also see a visual difference at the edges of the digits.


Any updates on this question?

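F.fold sums the values of overlapping blocks, and the blocks overlap here: with kernel_size=3 and the default stride=1, neighbouring 3x3 windows share pixels, so most pixels are added several times (up to 9), while border pixels are added fewer times.
If your windows don’t overlap, the round trip is exact: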

import torch
import torch.nn.functional as F
x = torch.rand(32, 3, 16, 16)
torch.equal(F.fold(F.unfold(x, 2, stride=2), x.shape[2], 2, stride=2), x)
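For overlapping windows, one common workaround (a sketch, not an official PyTorch inverse) is to count how many blocks cover each pixel by running a tensor of ones through the same unfold/fold pair, then divide the folded result by that count. The helper name `fold_unfold_inverse` is hypothetical. It assumes every pixel is covered by at least one block, which holds for kernel_size=3, stride=1 on 16x16 inputs. Because the summation and division involve floating-point rounding, compare with `torch.allclose` rather than `torch.equal`.

```python
import torch
import torch.nn.functional as F


def fold_unfold_inverse(x, kernel_size, stride=1):
    """Hypothetical helper: round-trip x through unfold/fold, then undo the overlap sum.

    Assumes every pixel is covered by at least one block; uncovered pixels
    would get a count of 0 and the division would produce NaN.
    """
    out_size = x.shape[-2:]  # (H, W) of the original input
    patches = F.unfold(x, kernel_size, stride=stride)
    summed = F.fold(patches, out_size, kernel_size, stride=stride)

    # Folding unfolded ones gives, for each pixel, the number of blocks covering it.
    ones = torch.ones_like(x)
    counts = F.fold(F.unfold(ones, kernel_size, stride=stride), out_size, kernel_size, stride=stride)

    # Each pixel was summed `counts` times, so dividing by counts recovers x.
    return summed / counts, counts


x = torch.rand(32, 3, 16, 16)
recovered, counts = fold_unfold_inverse(x, 3)
print(torch.allclose(recovered, x))           # close to x up to float rounding
print(counts[0, 0, 0, 0], counts[0, 0, 8, 8])  # corner pixel vs. interior pixel
```

The count map also explains the MNIST observation: a corner pixel is covered by 1 window and an interior pixel by 9, so without the division the edges of the image end up relatively darker than the interior.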