F.fold sums overlapping blocks, which is the case here.
If you don’t have overlapping windows, it’ll work:
import torch
import torch.nn.functional as F
x = torch.rand(32, 3, 16, 16)
torch.equal(F.fold(F.unfold(x, 2, stride=2), x.shape[2:], 2, stride=2), x)  # True: each pixel belongs to exactly one 2x2 block
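If you do need overlapping windows, one common workaround is to divide the folded result by the number of windows that cover each pixel. A minimal sketch, assuming a 3x3 kernel with stride 1 (these values are only illustrative, not taken from your code), where the coverage count is obtained by running a tensor of ones through the same unfold/fold pair:

```python
import torch
import torch.nn.functional as F

x = torch.rand(32, 3, 16, 16)
kernel_size, stride = 3, 1  # assumed overlapping setup, for illustration only

# Unfold into overlapping patches, then fold back: overlaps get summed.
patches = F.unfold(x, kernel_size, stride=stride)
summed = F.fold(patches, x.shape[2:], kernel_size, stride=stride)

# Count how many windows cover each pixel by folding a tensor of ones.
ones = torch.ones_like(x)
counts = F.fold(F.unfold(ones, kernel_size, stride=stride),
                x.shape[2:], kernel_size, stride=stride)

# Dividing by the coverage count undoes the summation.
recovered = summed / counts

print(torch.equal(summed, x))        # plain fold(unfold(x)) vs. x
print(torch.allclose(recovered, x))  # normalized result vs. x
```

This only works when every pixel is covered by at least one window, otherwise counts contains zeros. The comparison uses torch.allclose rather than torch.equal because summing and dividing introduces small floating-point differences.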