Thanks @ptrblck. As for the fix, I think @albanD fixed it in pull request #37099.
Regarding @Joshua_Clancy’s question above: I managed to solve the non-overlapping case, i.e., unfolding an image into non-overlapping windows and then putting them back together into the original shape:
import torch
import torch.nn.functional as f
torch.manual_seed(0)
N, C, H, W = 2, 1, 9, 9
k = 3
x = torch.randn(N, C, H, W)
print(x.shape) # torch.Size([2, 1, 9, 9])
assert H % k == 0 and W % k == 0
patches = x.unfold(2, k, k).unfold(3, k, k)
print(patches.shape) # torch.Size([2, 1, 3, 3, 3, 3]) - (N, C, H//k, W//k, k, k)
patches = patches.reshape(N, C, H//k, W//k, k*k)
print(patches.shape) # torch.Size([2, 1, 3, 3, 9]) - (N, C, H//k, W//k, k*k)
patches = patches.permute(0, 1, 4, 2, 3)
print(patches.shape) # torch.Size([2, 1, 9, 3, 3]) - (N, C, k*k, H//k, W//k)
patches = patches.squeeze(1)
print(patches.shape) # torch.Size([2, 9, 3, 3]) - (N, k*k, H//k, W//k)
patches = patches.view(N, k*k, -1)
print(patches.shape) # torch.Size([2, 9, 9]) - (N, k*k, H//k * W//k)
folded = f.fold(patches, (H, W), kernel_size=k, stride=k)
print(folded.shape) # torch.Size([2, 1, 9, 9])
print(torch.eq(x, folded).all()) # True
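Incidentally, I think the same non-overlapping round trip can be written more compactly with f.unfold, which already returns the (N, C*k*k, L) layout that fold expects (a sketch, reusing x and k from above):
cols = f.unfold(x, kernel_size=k, stride=k)
print(cols.shape) # torch.Size([2, 9, 9]) - (N, C*k*k, (H//k) * (W//k))
back = f.fold(cols, (H, W), kernel_size=k, stride=k)
print(torch.eq(x, back).all()) # True, since the windows don't overlap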
I’m finding it harder, though, to solve the “overlapping” case, which is the one @Joshua_Clancy was actually asking about:
import torch
import torch.nn.functional as f
torch.manual_seed(0)
C, H, W = 1, 9, 9
k = 3 # kernel size
s = 1 # stride
# Let's assume that our batch has only 1 sample, for simplicity
x = torch.arange(0, H*W).view(1, C, H, W)
# im2 = torch.arange(0, H*W).view(1, C, H, W)
# x = torch.cat((x, im2), dim=0).type(torch.float32) # shape = [2, 1, 9, 9]
N = x.shape[0]
patches = x.unfold(2, k, s)
print("unfold: ", patches.shape) # [2, 1, 7, 9, 3]
patches = patches.unfold(3, k, s)
print("unfold 2: ", patches.shape) # [2, 1, 7, 7, 3, 3]
patches = patches.reshape(N, C, H-(k-1), W-(k-1), k*k)
print("reshape : ", patches.shape) # [2, 1, 7, 7, 3*3]
patches = patches.permute(0, 1, 4, 2, 3)
print("after permute: ", patches.shape) # [2, 1, 3*3, 7, 7]
patches = patches.squeeze(1)
print("squeeze: ", patches.shape) # [2, 3*3, 7, 7]
patches = patches.view(N, k*k, -1)
print("view: ", patches.shape) # [2, 3*3, 7*7]
folded = f.fold(patches.type(torch.float32), (H, W), kernel_size=k, stride=s, dilation=1)
print("folded: ", folded.shape) # [1, 1, 9, 9
If we now compare x and folded:
>>> x
tensor([[[[ 0, 1, 2, 3, 4, 5, 6, 7, 8],
[ 9, 10, 11, 12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23, 24, 25, 26],
[27, 28, 29, 30, 31, 32, 33, 34, 35],
[36, 37, 38, 39, 40, 41, 42, 43, 44],
[45, 46, 47, 48, 49, 50, 51, 52, 53],
[54, 55, 56, 57, 58, 59, 60, 61, 62],
[63, 64, 65, 66, 67, 68, 69, 70, 71],
[72, 73, 74, 75, 76, 77, 78, 79, 80]]]])
>>> folded.type(torch.int32)
tensor([[[[ 0, 2, 6, 9, 12, 15, 18, 14, 8],
[ 18, 40, 66, 72, 78, 84, 90, 64, 34],
[ 54, 114, 180, 189, 198, 207, 216, 150, 78],
[ 81, 168, 261, 270, 279, 288, 297, 204, 105],
[108, 222, 342, 351, 360, 369, 378, 258, 132],
[135, 276, 423, 432, 441, 450, 459, 312, 159],
[162, 330, 504, 513, 522, 531, 540, 366, 186],
[126, 256, 390, 396, 402, 408, 414, 280, 142],
[ 72, 146, 222, 225, 228, 231, 234, 158, 80]]]], dtype=torch.int32)
we can see that folded has summed over the overlapping elements, which is the expected behavior, as explained in a note in the docs for Fold. For example, position (1, 1) is covered by 4 windows (2 along each dimension), so folded[0, 0, 1, 1] = 4 * 10 = 40.
So to retrieve the initial tensor, I guess we could divide each element by the number of times the filter “passed” over it. For instance, in our 9x9 case, the central 5x5 region (where every element is covered by all 9 placements of the 3x3 window) should be divided by 9:
folded[:, :, k-1:-(k-1), k-1:-(k-1)] = folded[:, :, k-1:-(k-1), k-1:-(k-1)] / 9
giving (with x shown again for comparison):
>>> x
tensor([[[[ 0, 1, 2, 3, 4, 5, 6, 7, 8],
[ 9, 10, 11, 12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23, 24, 25, 26],
[27, 28, 29, 30, 31, 32, 33, 34, 35],
[36, 37, 38, 39, 40, 41, 42, 43, 44],
[45, 46, 47, 48, 49, 50, 51, 52, 53],
[54, 55, 56, 57, 58, 59, 60, 61, 62],
[63, 64, 65, 66, 67, 68, 69, 70, 71],
[72, 73, 74, 75, 76, 77, 78, 79, 80]]]])
>>> folded.type(torch.int32)
tensor([[[[ 0, 2, 6, 9, 12, 15, 18, 14, 8],
[ 18, 40, 66, 72, 78, 84, 90, 64, 34],
[ 54, 114, 20, 21, 22, 23, 24, 150, 78],
[ 81, 168, 29, 30, 31, 32, 33, 204, 105],
[108, 222, 38, 39, 40, 41, 42, 258, 132],
[135, 276, 47, 48, 49, 50, 51, 312, 159],
[162, 330, 56, 57, 58, 59, 60, 366, 186],
[126, 256, 390, 396, 402, 408, 414, 280, 142],
[ 72, 146, 222, 225, 228, 231, 234, 158, 80]]]], dtype=torch.int32)
Note, for instance, that the corner pixels ((0, 0), (0, 8), (8, 0) and (8, 8)) should be divided by 1, since each of them is covered by only one window placement; they therefore already hold the values of the initial tensor at those coordinates.
I am not sure this is a clean way to approach this “retrieval” of the initial tensor after an unfold operation. Would it be possible for fold not to “calculate each combined value in the resulting large tensor by summing all values from all containing blocks”? But then again, maybe there’s a reason why fold should work this way and I am just missing the whole point
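For what it’s worth, a more general way to undo the summation (without hand-picking regions to divide) seems to be to compute the divisor with fold itself: folding a tensor of ones gives, at each position, the number of windows that covered it. A sketch, reusing x, patches, and the hyperparameters from above:
# Folding a ones tensor counts how many windows cover each position;
# dividing the summed result by these counts recovers the original tensor
ones = torch.ones(N, C, H, W)
divisor = f.fold(f.unfold(ones, kernel_size=k, stride=s), (H, W), kernel_size=k, stride=s)
summed = f.fold(patches.type(torch.float32), (H, W), kernel_size=k, stride=s)
recovered = summed / divisor
print(torch.allclose(x.type(torch.float32), recovered)) # True
This handles borders and corners automatically, since divisor is simply 1 wherever only a single window passed over the element.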
Thanks!