# Inverse of the `torch.Tensor.unfold` function

Hello,
I need a fast way to merge the output of `torch.Tensor.unfold` back together, but I cannot find anything like that, and `for` loops are too slow for this case…

Example of what I would like to have:

``````python
import torch

A = torch.rand(4, 2, 7, 21)

# the following creates a view in which dimension 3 (the last one)
# is split into windows
# * window size (kernel): 4
# * stride: 3
A = A.unfold(3, 4, 3)
# shape: torch.Size([4, 2, 7, 6, 4])

# [...] some operation on the windows

# now merge the windows back, averaging the overlapping
# points; the desired call would be something like:
B = A.fold(4, 4, 3, 'mean')

``````
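PyTorch does not seem to offer a `Tensor.fold` like the one wished for above. Here is a minimal sketch of one, assuming the windows came from `x.unfold(dim, size, step)` and that overlaps should be averaged. It scatters every window element back to its original position (`w * step + c` for element `c` of window `w`) with `Tensor.index_add_`, then divides by how many windows cover each position. The name `fold_mean` and its extra `length` argument (the original size of `dim`) are hypothetical, not a PyTorch API.

```python
import torch


def fold_mean(windows: torch.Tensor, dim: int, size: int, step: int, length: int) -> torch.Tensor:
    """Hypothetical inverse of `x.unfold(dim, size, step)` that averages overlaps.

    `windows` has the layout produced by `unfold`: the window index sits at
    position `dim` and the window contents are in the last dimension.
    `length` is the original size of dimension `dim`.
    """
    dim = dim % (windows.dim() - 1)  # normalize a negative `dim` of the original tensor
    n = windows.shape[dim]  # number of windows
    device = windows.device
    # Element c of window w comes from position w * step + c.
    pos = (torch.arange(n, device=device)[:, None] * step
           + torch.arange(size, device=device)[None, :]).reshape(-1)
    # Put each window's contents right after the window index, then merge the two.
    flat = windows.movedim(-1, dim + 1).flatten(dim, dim + 1)
    out_shape = list(windows.shape[:-1])
    out_shape[dim] = length
    summed = windows.new_zeros(out_shape).index_add_(dim, pos, flat)
    # How many windows cover each position (0 where no window reaches).
    count = torch.zeros(length, dtype=windows.dtype, device=device)
    count.index_add_(0, pos, torch.ones(pos.numel(), dtype=windows.dtype, device=device))
    view_shape = [1] * summed.dim()
    view_shape[dim] = length
    return summed / count.clamp(min=1).view(view_shape)
```

With the example above, `B = fold_mean(A, 3, 4, 3, 21)` would play the role of the wished-for `A.fold(...)`. Positions that no window covers (the tail `unfold` drops when `(length - size) % step != 0`) come back as zeros; for a length of 21 with size 4 and step 3, that means positions 19 and 20.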

Could you try `torch.nn.functional.fold`?

Well, I don't completely understand how it works… and, by the way, it only accepts 3D input, but mine is 5D.

Yeah, `fold` might not be flexible enough for your use case. However, since you are creating overlapping windows, how would you like to reduce the overlapping sections?
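`torch.nn.functional.fold` expects a batched input of shape `(N, C * prod(kernel_size), L)`, but a 5D unfolded tensor can be reshaped to fit: merge all leading dimensions into `N`, treat each window as a `(1, size)` kernel with `C = 1`, and let the `L` windows slide along a `(1, length)` image. `fold` sums overlapping entries, so folding a tensor of ones gives the per-position count, and dividing the two averages the overlaps. This is a sketch assuming windows along the last dimension only; `fold_last_dim_mean` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def fold_last_dim_mean(windows: torch.Tensor, size: int, step: int, length: int) -> torch.Tensor:
    """Average overlapping windows produced by `x.unfold(-1, size, step)` (hypothetical helper)."""
    *lead, n, k = windows.shape
    assert k == size, "last dimension must hold the window contents"
    # F.fold wants (N, C * kernel_h * kernel_w, L); here C = 1, kernel = (1, size), L = n.
    cols = windows.reshape(-1, n, k).transpose(1, 2)
    fold_args = dict(output_size=(1, length), kernel_size=(1, size), stride=(1, step))
    summed = F.fold(cols, **fold_args)                   # (N, 1, 1, length), overlaps summed
    count = F.fold(torch.ones_like(cols), **fold_args)  # windows covering each position
    # Positions no window covers have count 0; clamp so they stay 0 instead of NaN.
    return (summed / count.clamp(min=1)).reshape(*lead, length)
```

Unlike a Python loop over windows, both `fold` calls run as single library operations, at the cost of materializing the reshaped copy of the windows.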

I think that I’ve found a workaround.

The idea is that after the unfolding you have a tensor with the windows in the last dimension:

``````
In [10]: a = torch.arange(4*1*2*15).float().reshape(4, 1, 2, 15)

In [11]: a
Out[11]:
tensor([[[[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
            11.,  12.,  13.,  14.],
          [ 15.,  16.,  17.,  18.,  19.,  20.,  21.,  22.,  23.,  24.,  25.,
            26.,  27.,  28.,  29.]]],

        [[[ 30.,  31.,  32.,  33.,  34.,  35.,  36.,  37.,  38.,  39.,  40.,
            41.,  42.,  43.,  44.],
          [ 45.,  46.,  47.,  48.,  49.,  50.,  51.,  52.,  53.,  54.,  55.,
            56.,  57.,  58.,  59.]]],

        [[[ 60.,  61.,  62.,  63.,  64.,  65.,  66.,  67.,  68.,  69.,  70.,
            71.,  72.,  73.,  74.],
          [ 75.,  76.,  77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,
            86.,  87.,  88.,  89.]]],

        [[[ 90.,  91.,  92.,  93.,  94.,  95.,  96.,  97.,  98.,  99., 100.,
           101., 102., 103., 104.],
          [105., 106., 107., 108., 109., 110., 111., 112., 113., 114., 115.,
           116., 117., 118., 119.]]]])

In [12]: k = 5; s = 2

In [13]: a_unf = a.unfold(3, k, s)

In [14]: a_unf
Out[14]:
tensor([[[[[  0.,   1.,   2.,   3.,   4.],
           [  2.,   3.,   4.,   5.,   6.],
           [  4.,   5.,   6.,   7.,   8.],
           [  6.,   7.,   8.,   9.,  10.],
           [  8.,   9.,  10.,  11.,  12.],
           [ 10.,  11.,  12.,  13.,  14.]],

          [[ 15.,  16.,  17.,  18.,  19.],
           [ 17.,  18.,  19.,  20.,  21.],
           [ 19.,  20.,  21.,  22.,  23.],
           [ 21.,  22.,  23.,  24.,  25.],
           [ 23.,  24.,  25.,  26.,  27.],
           [ 25.,  26.,  27.,  28.,  29.]]]],

        [[[[ 30.,  31.,  32.,  33.,  34.],
           [ 32.,  33.,  34.,  35.,  36.],
           [ 34.,  35.,  36.,  37.,  38.],
           [ 36.,  37.,  38.,  39.,  40.],
           [ 38.,  39.,  40.,  41.,  42.],
           [ 40.,  41.,  42.,  43.,  44.]],

          [[ 45.,  46.,  47.,  48.,  49.],
           [ 47.,  48.,  49.,  50.,  51.],
           [ 49.,  50.,  51.,  52.,  53.],
           [ 51.,  52.,  53.,  54.,  55.],
           [ 53.,  54.,  55.,  56.,  57.],
           [ 55.,  56.,  57.,  58.,  59.]]]],

        [[[[ 60.,  61.,  62.,  63.,  64.],
           [ 62.,  63.,  64.,  65.,  66.],
           [ 64.,  65.,  66.,  67.,  68.],
           [ 66.,  67.,  68.,  69.,  70.],
           [ 68.,  69.,  70.,  71.,  72.],
           [ 70.,  71.,  72.,  73.,  74.]],

          [[ 75.,  76.,  77.,  78.,  79.],
           [ 77.,  78.,  79.,  80.,  81.],
           [ 79.,  80.,  81.,  82.,  83.],
           [ 81.,  82.,  83.,  84.,  85.],
           [ 83.,  84.,  85.,  86.,  87.],
           [ 85.,  86.,  87.,  88.,  89.]]]],

        [[[[ 90.,  91.,  92.,  93.,  94.],
           [ 92.,  93.,  94.,  95.,  96.],
           [ 94.,  95.,  96.,  97.,  98.],
           [ 96.,  97.,  98.,  99., 100.],
           [ 98.,  99., 100., 101., 102.],
           [100., 101., 102., 103., 104.]],

          [[105., 106., 107., 108., 109.],
           [107., 108., 109., 110., 111.],
           [109., 110., 111., 112., 113.],
           [111., 112., 113., 114., 115.],
           [113., 114., 115., 116., 117.],
           [115., 116., 117., 118., 119.]]]]])

``````

Now I want to average the overlapping elements, which in the example above are easy to identify because they contain the same number (no processing has been applied yet).

The idea is to add the overlapping values into the last `s` columns and then divide them. We need a tensor to keep track of how many additions are performed on each element. Moreover, `unfold` creates a view, so overlapping windows share the same memory: modifying an entry in the last `s` columns also modifies the corresponding entries in the other columns (and in the original tensor). Thus we need to adjust the values in the “edited” columns (alternatively, we could copy the unfolded view into a new tensor, if we don't care about RAM).
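A small check of the view behaviour, using the same `k = 5`, `s = 2` as above on a single row: writing one element through a window also changes the original tensor and every other window that contains that position.

```python
import torch

a = torch.arange(15.).reshape(1, 15)
a_unf = a.unfold(1, 5, 2)  # shape (1, 6, 5), a view of `a`
print(a_unf.stride())      # (15, 2, 1): consecutive windows start 2 elements apart
a_unf[0, 0, 4] = -1.0      # write element 4 of window 0 ...
print(a[0, 4].item())      # -1.0: ... which is a[0, 4]
print(a_unf[0, 1, 2].item(), a_unf[0, 2, 0].item())  # -1.0 -1.0: the same element in windows 1 and 2
```

Since all copies of an overlapped element share memory, windows processed in place through the view can never disagree; overlaps only hold different values (and need averaging) when the processing writes its result to a new tensor.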

``````python
import torch

def fix_unfolded_view(a_unf, stride):
    k = a_unf.shape[-1]  # window size
    s = stride
    # how many values have been summed into each of the last `s` columns
    divisor = torch.ones(*a_unf.shape[:-1], s, dtype=a_unf.dtype, device=a_unf.device)
    for i in range(s):
        for j in range(1, (k // s) + 1):
            # column of window `w + j` holding the same position as column (k-1)-i of window `w`
            idx = (k - 1) - j * s - i
            if idx < 0:
                break
            a_unf[..., :-j, (k - 1) - i] += a_unf[..., j:, idx] / divisor[..., :-j, -i - 1]
            divisor[..., :-j, -i - 1] += 1
    a_unf[..., k - s:] /= divisor
``````
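The loops rely on element `c` of window `w` coming from position `w * s + c` of the unfolded dimension. A short sketch (same length-15, `k = 5`, `s = 2` example) that builds this index map explicitly and counts how many windows cover each position, which is the number the averaging has to divide by:

```python
import torch

length, k, s = 15, 5, 2
n = (length - k) // s + 1  # number of windows produced by unfold: 6
# pos[w, c] = original position of element c of window w
pos = torch.arange(n)[:, None] * s + torch.arange(k)[None, :]
# how many windows cover each original position
counts = torch.bincount(pos.reshape(-1), minlength=length)
print(pos[1])   # tensor([2, 3, 4, 5, 6])
print(counts)   # tensor([1, 1, 2, 2, 3, 2, 3, 2, 3, 2, 3, 2, 2, 1, 1])
```

For instance, position 4 (the value `4.` in the printout above) appears in windows 0, 1 and 2, hence a count of 3.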

Finally, the elements in columns `[s:-s]` of the last dimension (in the example, the column starting with `2`) also overlap but were not touched, so they need fixing too. To fix them, we can call the function recursively on all but the last `s` columns, with a stride equal to `k-2*s`:

``````python
import torch

def fix_unfolded_view(a_unf, stride):
    k = a_unf.shape[-1]  # window size
    s = stride
    # how many values have been summed into each of the last `s` columns
    divisor = torch.ones(*a_unf.shape[:-1], s, dtype=a_unf.dtype, device=a_unf.device)
    for i in range(s):
        for j in range(1, (k // s) + 1):
            idx = (k - 1) - j * s - i
            if idx < 0:
                break
            a_unf[..., :-j, (k - 1) - i] += a_unf[..., j:, idx] / divisor[..., :-j, -i - 1]
            divisor[..., :-j, -i - 1] += 1
    a_unf[..., k - s:] /= divisor
    del divisor  # free the memory before recursing
    if k - 2 * s > 1:
        # the middle columns [s:-s] overlap too: recurse on all but the last `s` columns
        fix_unfolded_view(a_unf[..., :-s], k - 2 * s)
``````

We can then read the result from the original array `a`, which has been modified in place, since we only used views.
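When testing a vectorized version, a straightforward (slow) reference that loops over the windows can be handy. This sketch assumes windows taken along the last dimension with `x.unfold(-1, k, s)` and held in their own tensor (for example the output of out-of-place processing), and averages every position over the windows that cover it; `naive_fold_mean` is a hypothetical name.

```python
import torch


def naive_fold_mean(windows: torch.Tensor, s: int, length: int) -> torch.Tensor:
    """Reference: average overlapping windows along the last original dimension with a Python loop."""
    *lead, n, k = windows.shape
    total = windows.new_zeros((*lead, length))
    count = windows.new_zeros((length,))
    for w in range(n):
        # window w covers positions w * s ... w * s + k - 1
        total[..., w * s : w * s + k] += windows[..., w, :]
        count[w * s : w * s + k] += 1
    # uncovered positions (count 0) stay 0
    return total / count.clamp(min=1)
```

Comparing a faster version against this on random data (with `torch.allclose`) is a cheap way to catch indexing mistakes.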

The double `for` loop runs `k - s` times in total (at most `k`, and fewer the larger `s` is), and each iteration is a single vectorized tensor operation, so we can reasonably expect only 3 to 5 Python-level iterations.
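Counting how often the body of the double loop in `fix_unfolded_view` runs (plain Python, no tensors) shows that, per call and ignoring the recursion, it is exactly `k - s` iterations whenever `k >= s`. For a given `i` the inner loop runs `(k - 1 - i) // s` times, and summing that over `i` in `range(s)` gives `k - s` by Hermite's identity:

```python
def loop_iterations(k: int, s: int) -> int:
    """Count how many times the body of the double loop in fix_unfolded_view runs."""
    count = 0
    for i in range(s):
        for j in range(1, (k // s) + 1):
            idx = (k - 1) - j * s - i
            if idx < 0:
                break
            count += 1
    return count


print(loop_iterations(5, 2))  # 3
print(loop_iterations(4, 3))  # 1
```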

The recursive call isn't pretty, but it usually recurses only once or twice.