How to Unfold a tensor into sliding windows and then fold it back

Hello!

So I have been trying to unfold an image tensor into multiple sliding windows and then fold it back into an image. I have found multiple threads about this, but none of them solved my problem. So far I have successfully unfolded my images into sliding windows like this:

import torch

im = torch.arange(0, 81).view(1,1,9,9)
im2 = torch.arange(0, 81).view(1,1,9,9)
x = torch.cat((im, im2), dim=0) # shape = [2,1,9,9]

patches = x.unfold(2,3,1)
print("unfold: ", patches.shape) # [2,1,7,9,3]

patches = patches.unfold(3,3,1)
print("unfold 2: ", patches.shape) # [2, 1, 7, 7, 3, 3]

patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
print("after permute: ", patches.shape) #[2, 7, 7, 1, 3, 3]

patches = patches.view(patches.size()[0], -1, patches.size()[4], patches.size()[5])
print(patches.shape) # [2, 49, 3, 3]
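For comparison, `torch.nn.functional.unfold` extracts the same 3x3 windows but returns them directly in the `(N, C*k*k, L)` layout that fold works with (one column per window). The sketch below checks that the two agree. It converts the input to float first, as an assumption, because I'm not sure every PyTorch version accepts integer tensors in `F.unfold`:

```python
import torch
import torch.nn.functional as F

k, s = 3, 1
im = torch.arange(0, 81).view(1, 1, 9, 9)
x = torch.cat((im, im), dim=0).float()  # shape [2, 1, 9, 9]; float to be safe with F.unfold

# Manual sliding windows, as in the post above
patches = x.unfold(2, k, s).unfold(3, k, s)             # [2, 1, 7, 7, 3, 3]
patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
patches = patches.view(patches.size(0), -1, k, k)        # [2, 49, 3, 3]

# Built-in version: shape (N, C*k*k, L) with L = 7*7 windows
cols = F.unfold(x, kernel_size=k, stride=s)
print(cols.shape)  # torch.Size([2, 9, 49])

# Same numbers: flatten each 3x3 window and move the window index last
same = patches.reshape(2, 49, k * k).transpose(1, 2)
print(torch.equal(cols, same))  # True
```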

At this point I have a bunch of sliding windows that I can work with, but before building a network around them I wanted to make sure that I could fold them back. I have been working on this for a while now and seem to be stuck.

I have tried using fold in these ways:

import torch.nn as nn

fold = nn.Fold(output_size = (9,9), kernel_size = (3,3))
together = fold(patches)
print(together.shape)

fold = nn.Fold(output_size = 9, kernel_size = 3)
together = fold(patches)
print(together.shape)

But I keep running into problems. One big concern is that the documentation for fold seems wrong to me. It says:

.. warning::
    Currently, only 4-D output tensors (batched image-like tensors) are
    supported.

But when I try to fold 4D tensors, I get this error:

Input Error: Only 3D input Tensors are supported (got 4D)

Any help would be much appreciated! Thanks.
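The error is consistent with fold's actual contract: the *input* must be 3D, `(N, C*k*k, L)`, while the *output* is 4D, `(N, C, H, W)`; the "4-D" in the warning refers to the output. A minimal sketch, assuming the `[2, 49, 3, 3]` patches from the question (rebuilt here so the snippet runs on its own), that converts them to the 3D layout `nn.Fold` accepts:

```python
import torch
import torch.nn as nn

k = 3
x = torch.arange(0, 81).view(1, 1, 9, 9).repeat(2, 1, 1, 1).float()  # [2, 1, 9, 9]

patches = x.unfold(2, k, 1).unfold(3, k, 1)                                 # [2, 1, 7, 7, 3, 3]
patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous().view(2, -1, k, k)  # [2, 49, 3, 3]

# nn.Fold wants (N, C*k*k, L): flatten each window, then put the window index last
cols = patches.reshape(2, 49, k * k).permute(0, 2, 1)  # [2, 9, 49]

fold = nn.Fold(output_size=(9, 9), kernel_size=(3, 3))
together = fold(cols)
print(together.shape)  # torch.Size([2, 1, 9, 9])
```

The shape now comes out right, but with overlapping windows the values are summed, so `together` is not yet equal to `x`.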

I confirm that I’ve seen the exact same issue.


I’m experiencing similar issues (PyTorch version 1.4.0).

EDIT: looking at the source code at https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py, I can see a clear mismatch between the documentation and the implementation:

def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
    # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor  # noqa
    r"""Combines an array of sliding local blocks into a large containing
    tensor.
    .. warning::
        Currently, only 4-D output tensors (batched image-like tensors) are
        supported.
    See :class:`torch.nn.Fold` for details
    """
    if not torch.jit.is_scripting():
        if type(input) is not Tensor and has_torch_function((input,)):
            return handle_torch_function(
                fold, (input,), input, output_size, kernel_size, dilation=dilation,
                padding=padding, stride=stride)
    if input.dim() == 3:
        msg = '{} must be int or 2-tuple for 3D input'
        assert_int_or_pair(output_size, 'output_size', msg)
        assert_int_or_pair(kernel_size, 'kernel_size', msg)
        assert_int_or_pair(dilation, 'dilation', msg)
        assert_int_or_pair(padding, 'padding', msg)
        assert_int_or_pair(stride, 'stride', msg)

        return torch._C._nn.col2im(input, _pair(output_size), _pair(kernel_size),
                                   _pair(dilation), _pair(padding), _pair(stride))
    else:
        raise NotImplementedError("Input Error: Only 3D input Tensors are supported (got {}D)".format(input.dim()))

The inline docs seem to be misleading. Could you please create an issue so that we can improve them?

Thanks @ptrblck, I created issue #37063. Let me know if it would have been better to file it as a bug report instead of a doc issue.

As far as I understand the method, it does the “right” thing, so I think it’s a documentation issue.

Thanks for creating the issue!
Would you be interested in creating a fix for it?


Thanks @ptrblck. About the fix, I think @albanD fixed this in pull request #37099.

With regard to @Joshua_Clancy’s question above: I managed to solve the non-overlapping case, i.e., unfolding an image into non-overlapping windows and then putting them back together into the original shape:

import torch
import torch.nn.functional as f
torch.manual_seed(0)

N, C, H, W = 2, 1, 9, 9
k = 3

x = torch.randn(N, C, H, W)
print(x.shape) # torch.Size([2, 1, 9, 9])

assert H % k == 0 and W % k == 0

patches = x.unfold(2, k, k).unfold(3, k, k)         
print(patches.shape) # torch.Size([2, 1, 3, 3, 3, 3]) - (N, C, H//k, W//k, k, k) 

patches = patches.reshape(N, C, H//k, W//k, k*k)
print(patches.shape) # torch.Size([2, 1, 3, 3, 9])    - (N, C, H//k, W//k, k*k)

patches = patches.permute(0, 1, 4, 2, 3)
print(patches.shape) # torch.Size([2, 1, 9, 3, 3])    - (N, C, k*k, H//k, W//k)

patches = patches.squeeze(1)
print(patches.shape) # torch.Size([2, 9, 3, 3])       - (N, k*k, H//k, W//k)

patches = patches.view(N, k*k, -1)
print(patches.shape) # torch.Size([2, 9, 9])          - (N, k*k, H//k * W//k)

folded = f.fold(patches, (H, W), kernel_size=k, stride=k)
print(folded.shape)  # torch.Size([2, 1, 9, 9])

print(torch.eq(x, folded).all()) # True
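The same round trip can be written with the built-in `F.unfold`, which also avoids the `squeeze(1)` step that restricts the version above to single-channel inputs. A sketch, assuming H and W are multiples of k (stride equals kernel size, so the windows don't overlap and fold simply puts each block back in place):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 2, 3, 9, 9   # C > 1 works as well
k = 3
assert H % k == 0 and W % k == 0

x = torch.randn(N, C, H, W)

# (N, C*k*k, L) with L = (H//k) * (W//k) non-overlapping blocks
patches = F.unfold(x, kernel_size=k, stride=k)
print(patches.shape)  # torch.Size([2, 27, 9])

folded = F.fold(patches, output_size=(H, W), kernel_size=k, stride=k)
print(torch.equal(x, folded))  # True
```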

I’m finding it harder, though, to solve the “overlapping” case, which is the one @Joshua_Clancy was actually asking about:

import torch
import torch.nn.functional as f
torch.manual_seed(0)

C, H, W = 1, 9, 9
k = 3 # kernel size
s = 1 # stride

# Let's assume that our batch has only 1 sample, for simplicity
x = torch.arange(0, H*W).view(1, C, H, W)
# im2 = torch.arange(0, H*W).view(1, C, H, W)
# x = torch.cat((x, im2), dim=0).type(torch.float32) # shape = [2, 1, 9, 9]

N = x.shape[0]

patches = x.unfold(2, k, s)
print("unfold:        ", patches.shape) # [1, 1, 7, 9, 3]

patches = patches.unfold(3, k, s)
print("unfold 2:      ", patches.shape) # [1, 1, 7, 7, 3, 3]

patches = patches.reshape(N, C, H-(k-1), W-(k-1), k*k)
print("reshape:       ", patches.shape) # [1, 1, 7, 7, 3*3]

patches = patches.permute(0, 1, 4, 2, 3)
print("after permute: ", patches.shape) # [1, 1, 3*3, 7, 7]

patches = patches.squeeze(1)
print("squeeze:       ", patches.shape) # [1, 3*3, 7, 7]

patches = patches.view(N, k*k, -1)
print("view:          ", patches.shape) # [1, 3*3, 7*7]

# convert to float before folding; overlapping values get summed
folded = f.fold(patches.type(torch.float32), (H, W), kernel_size=k, stride=s, dilation=1)
print("folded:        ", folded.shape) # [1, 1, 9, 9]

If we now compare x and folded:

>>> x
tensor([[[[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
          [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
          [18, 19, 20, 21, 22, 23, 24, 25, 26],
          [27, 28, 29, 30, 31, 32, 33, 34, 35],
          [36, 37, 38, 39, 40, 41, 42, 43, 44],
          [45, 46, 47, 48, 49, 50, 51, 52, 53],
          [54, 55, 56, 57, 58, 59, 60, 61, 62],
          [63, 64, 65, 66, 67, 68, 69, 70, 71],
          [72, 73, 74, 75, 76, 77, 78, 79, 80]]]])

>>> folded.type(torch.int32)
tensor([[[[  0,   2,   6,   9,  12,  15,  18,  14,   8],
          [ 18,  40,  66,  72,  78,  84,  90,  64,  34],
          [ 54, 114, 180, 189, 198, 207, 216, 150,  78],
          [ 81, 168, 261, 270, 279, 288, 297, 204, 105],
          [108, 222, 342, 351, 360, 369, 378, 258, 132],
          [135, 276, 423, 432, 441, 450, 459, 312, 159],
          [162, 330, 504, 513, 522, 531, 540, 366, 186],
          [126, 256, 390, 396, 402, 408, 414, 280, 142],
          [ 72, 146, 222, 225, 228, 231, 234, 158,  80]]]], dtype=torch.int32)

we can see that folded has summed over overlapping elements, which is the expected behavior, as explained in a Note in the docs for Fold.
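That reading of the docs can be checked directly: the number of windows covering pixel (r, c) is a row count times a column count, and for a side of 9 with k = 3, s = 1 each count runs 1, 2, 3, 3, 3, 3, 3, 2, 1. A small sketch (the helper `coverage_1d` is hypothetical, written just for this check) confirming that `folded == x * count`:

```python
import torch
import torch.nn.functional as F

H = W = 9
k, s = 3, 1
x = torch.arange(0, H * W, dtype=torch.float32).view(1, 1, H, W)

folded = F.fold(F.unfold(x, kernel_size=k, stride=s), (H, W), kernel_size=k, stride=s)

def coverage_1d(n, k, s):
    # Number of windows (starting at 0, s, 2s, ...) that contain index i
    starts = range(0, n - k + 1, s)
    return torch.tensor([sum(st <= i < st + k for st in starts) for i in range(n)],
                        dtype=torch.float32)

rows, cols = coverage_1d(H, k, s), coverage_1d(W, k, s)
print(rows)  # tensor([1., 2., 3., 3., 3., 3., 3., 2., 1.])
count = torch.outer(rows, cols)  # number of windows covering pixel (r, c)

print(torch.equal(folded, x * count))  # True
```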

So to retrieve the initial tensor, I guess we could divide each element by the number of times the filter was “passed” over it. For instance, in our 9x9 case, every pixel in the central 5x5 region (rows and columns 2 to 6) lies inside 9 different windows, so that region should be divided by 9:

folded[:, :, k-1:-(k-1), k-1:-(k-1)] = folded[:, :, k-1:-(k-1), k-1:-(k-1)] / 9

giving:

>>> x
tensor([[[[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
          [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
          [18, 19, 20, 21, 22, 23, 24, 25, 26],
          [27, 28, 29, 30, 31, 32, 33, 34, 35],
          [36, 37, 38, 39, 40, 41, 42, 43, 44],
          [45, 46, 47, 48, 49, 50, 51, 52, 53],
          [54, 55, 56, 57, 58, 59, 60, 61, 62],
          [63, 64, 65, 66, 67, 68, 69, 70, 71],
          [72, 73, 74, 75, 76, 77, 78, 79, 80]]]])
>>> folded.type(torch.int32)
tensor([[[[  0,   2,   6,   9,  12,  15,  18,  14,   8],
          [ 18,  40,  66,  72,  78,  84,  90,  64,  34],
          [ 54, 114,  20,  21,  22,  23,  24, 150,  78],
          [ 81, 168,  29,  30,  31,  32,  33, 204, 105],
          [108, 222,  38,  39,  40,  41,  42, 258, 132],
          [135, 276,  47,  48,  49,  50,  51, 312, 159],
          [162, 330,  56,  57,  58,  59,  60, 366, 186],
          [126, 256, 390, 396, 402, 408, 414, 280, 142],
          [ 72, 146, 222, 225, 228, 231, 234, 158,  80]]]], dtype=torch.int32)

Note, for instance, that the corner pixels ((0, 0), (0, 8), (8, 0) and (8, 8)) should be divided by 1, since each of them appears in only one window. Therefore, they already equal the values of the initial tensor at those coordinates.
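Rather than working out the divisor region by region, the per-pixel count can be obtained by folding an unfolded image of ones, and then dividing by it. A sketch; `fold_average` is a hypothetical helper, and it assumes every pixel is covered by at least one window (the `clamp` only guards against division by zero when a large stride leaves gaps):

```python
import torch
import torch.nn.functional as F

def fold_average(cols, output_size, k, s):
    """Fold (N, C*k*k, L) columns back to (N, C, H, W), averaging overlaps."""
    summed = F.fold(cols, output_size, kernel_size=k, stride=s)
    # Coverage count per pixel: unfold an image of ones, then fold it back (sums the 1s)
    ones = torch.ones(1, 1, *output_size, dtype=cols.dtype, device=cols.device)
    count = F.fold(F.unfold(ones, kernel_size=k, stride=s), output_size,
                   kernel_size=k, stride=s)
    return summed / count.clamp(min=1)  # count broadcasts over N and C

H = W = 9
k, s = 3, 1
x = torch.arange(0, H * W, dtype=torch.float32).view(1, 1, H, W)
cols = F.unfold(x, kernel_size=k, stride=s)  # (1, 9, 49)
x_rec = fold_average(cols, (H, W), k, s)
print(torch.allclose(x_rec, x))  # True
```

If the windows are modified before folding (e.g. by a network), this returns the per-pixel average of the overlapping predictions rather than the exact original.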

I am not sure this is a clean way to approach this “retrieval” of the initial tensor after an unfold operation. Would it be possible for fold not to “calculate each combined value in the resulting large tensor by summing all values from all containing blocks”?

But then again, maybe there’s a reason why fold works this way and I am just missing the whole point :slight_smile:
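One way to see why fold sums rather than averages: as far as I understand, fold is the adjoint (transpose) of unfold, i.e. <unfold(x), y> = <x, fold(y)>, which is also why the gradient of unfold is a fold. A numerical check of both statements, in float64 to keep rounding small:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 2, 3, 9, 9
k, s = 3, 1

x = torch.randn(N, C, H, W, dtype=torch.float64)
cols_shape = F.unfold(x, kernel_size=k, stride=s).shape  # (N, C*k*k, L)
y = torch.randn(cols_shape, dtype=torch.float64)

lhs = (F.unfold(x, kernel_size=k, stride=s) * y).sum()       # <unfold(x), y>
rhs = (x * F.fold(y, (H, W), kernel_size=k, stride=s)).sum()  # <x, fold(y)>
print(torch.allclose(lhs, rhs))  # True

# Gradient of <unfold(x), y> with respect to x equals fold(y)
x.requires_grad_(True)
(F.unfold(x, kernel_size=k, stride=s) * y).sum().backward()
print(torch.allclose(x.grad, F.fold(y, (H, W), kernel_size=k, stride=s)))  # True
```

So the summing behavior is what makes fold the exact counterpart of unfold for gradients; an averaging inverse has to be built on top of it, as in the division above.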

Thanks!
