Inverse of tensor `torch.unfold` function

Hello,
I need a fast way to merge the output of torch.Tensor.unfold back into a single tensor, but I cannot find anything that does this. for loops are far too slow for this case…

Example of what I would like to have:

import torch

A = torch.rand(4, 2, 7, 21)

# the following creates a view in which dimension 3 (the last one)
# is split into windows
# * window size (kernel): 4
# * stride: 3
A = A.unfold(3, 4, 3)
# shape: torch.Size([4, 2, 7, 6, 4])

# [...] some operation on the windows

# now merge the windows back, averaging the points where
# windows overlap; something like this (wished-for API):
B = A.fold(4, 4, 3, 'mean')
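To make the shape in the example easier to check, here is a small pure-Python sketch of which positions each window covers. The helper name `unfold_windows` is made up for illustration and is not part of PyTorch; it assumes the length is at least the window size.

```python
def unfold_windows(length, size, step):
    """Return the (start, stop) range covered by each window of an unfold."""
    # Number of complete windows that fit: floor((length - size) / step) + 1
    n_windows = (length - size) // step + 1
    return [(i * step, i * step + size) for i in range(n_windows)]


windows = unfold_windows(length=21, size=4, step=3)
print(len(windows))   # 6
print(windows[-1])    # (15, 19)
```

With a length of 21, the last window stops at position 19, so the final two positions are not part of any window and cannot be recovered from the unfolded tensor.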

Could you try torch.nn.functional.fold?

Well, I don’t completely understand how it works… and, by the way, it accepts only 3D input, but mine is 5D

Yeah, fold might not be flexible enough for your use case. However, since you are creating overlapping windows, how would you like to reduce the overlapping sections?
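For reference, one possible way around the 3D restriction is to flatten the leading dimensions and treat the last axis as an image of height 1. This is only a sketch: `fold_mean_last_dim` is a hypothetical helper name, it assumes the windows sit in the last two dimensions (as produced by `unfold` on the last axis), and it only rebuilds the positions the windows actually cover. `F.fold` sums overlapping values, so folding a tensor of ones gives the counts needed for a mean.

```python
import torch
import torch.nn.functional as F


def fold_mean_last_dim(windows, step):
    """Average overlapping windows of shape (..., n_windows, size) into (..., length)."""
    *lead, n_windows, size = windows.shape
    length = (n_windows - 1) * step + size  # positions covered by the windows
    # F.fold expects (N, C * kernel_elems, n_blocks); use C = 1 and a 1 x size kernel.
    cols = windows.reshape(-1, n_windows, size).transpose(1, 2).contiguous()
    kwargs = dict(output_size=(1, length), kernel_size=(1, size), stride=(1, step))
    summed = F.fold(cols, **kwargs)                   # overlapping values are summed
    counts = F.fold(torch.ones_like(cols), **kwargs)  # windows touching each position
    return (summed / counts).reshape(*lead, length)


a = torch.arange(4 * 1 * 2 * 15).float().reshape(4, 1, 2, 15)
merged = fold_mean_last_dim(a.unfold(3, 5, 2), step=2)
print(merged.shape)                # torch.Size([4, 1, 2, 15])
print(torch.allclose(merged, a))   # True, since the windows were not modified
```

Other reductions (for example a maximum) would need a different approach, because `F.fold` only sums.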

I think that I’ve found a workaround.

The idea is that after the unfolding you have a tensor with the windows in the last dimension:

In [10]: a = torch.arange(4*1*2*15).float().reshape(4, 1, 2, 15)                    

In [11]: a                                                                                     
Out[11]: 
tensor([[[[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
            11.,  12.,  13.,  14.],
          [ 15.,  16.,  17.,  18.,  19.,  20.,  21.,  22.,  23.,  24.,  25.,
            26.,  27.,  28.,  29.]]],


        [[[ 30.,  31.,  32.,  33.,  34.,  35.,  36.,  37.,  38.,  39.,  40.,
            41.,  42.,  43.,  44.],
          [ 45.,  46.,  47.,  48.,  49.,  50.,  51.,  52.,  53.,  54.,  55.,
            56.,  57.,  58.,  59.]]],


        [[[ 60.,  61.,  62.,  63.,  64.,  65.,  66.,  67.,  68.,  69.,  70.,
            71.,  72.,  73.,  74.],
          [ 75.,  76.,  77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,
            86.,  87.,  88.,  89.]]],


        [[[ 90.,  91.,  92.,  93.,  94.,  95.,  96.,  97.,  98.,  99., 100.,
           101., 102., 103., 104.],
          [105., 106., 107., 108., 109., 110., 111., 112., 113., 114., 115.,
           116., 117., 118., 119.]]]])


In [12]: k = 5; s = 2                                                                          

In [13]: a_unf = a.unfold(3, k, s)                                                             

In [14]: a_unf                                                                                 
Out[14]: 
tensor([[[[[  0.,   1.,   2.,   3.,   4.],
           [  2.,   3.,   4.,   5.,   6.],
           [  4.,   5.,   6.,   7.,   8.],
           [  6.,   7.,   8.,   9.,  10.],
           [  8.,   9.,  10.,  11.,  12.],
           [ 10.,  11.,  12.,  13.,  14.]],

          [[ 15.,  16.,  17.,  18.,  19.],
           [ 17.,  18.,  19.,  20.,  21.],
           [ 19.,  20.,  21.,  22.,  23.],
           [ 21.,  22.,  23.,  24.,  25.],
           [ 23.,  24.,  25.,  26.,  27.],
           [ 25.,  26.,  27.,  28.,  29.]]]],



        [[[[ 30.,  31.,  32.,  33.,  34.],
           [ 32.,  33.,  34.,  35.,  36.],
           [ 34.,  35.,  36.,  37.,  38.],
           [ 36.,  37.,  38.,  39.,  40.],
           [ 38.,  39.,  40.,  41.,  42.],
           [ 40.,  41.,  42.,  43.,  44.]],

          [[ 45.,  46.,  47.,  48.,  49.],
           [ 47.,  48.,  49.,  50.,  51.],
           [ 49.,  50.,  51.,  52.,  53.],
           [ 51.,  52.,  53.,  54.,  55.],
           [ 53.,  54.,  55.,  56.,  57.],
           [ 55.,  56.,  57.,  58.,  59.]]]],



        [[[[ 60.,  61.,  62.,  63.,  64.],
           [ 62.,  63.,  64.,  65.,  66.],
           [ 64.,  65.,  66.,  67.,  68.],
           [ 66.,  67.,  68.,  69.,  70.],
           [ 68.,  69.,  70.,  71.,  72.],
           [ 70.,  71.,  72.,  73.,  74.]],

          [[ 75.,  76.,  77.,  78.,  79.],
           [ 77.,  78.,  79.,  80.,  81.],
           [ 79.,  80.,  81.,  82.,  83.],
           [ 81.,  82.,  83.,  84.,  85.],
           [ 83.,  84.,  85.,  86.,  87.],
           [ 85.,  86.,  87.,  88.,  89.]]]],



        [[[[ 90.,  91.,  92.,  93.,  94.],
           [ 92.,  93.,  94.,  95.,  96.],
           [ 94.,  95.,  96.,  97.,  98.],
           [ 96.,  97.,  98.,  99., 100.],
           [ 98.,  99., 100., 101., 102.],
           [100., 101., 102., 103., 104.]],

          [[105., 106., 107., 108., 109.],
           [107., 108., 109., 110., 111.],
           [109., 110., 111., 112., 113.],
           [111., 112., 113., 114., 115.],
           [113., 114., 115., 116., 117.],
           [115., 116., 117., 118., 119.]]]]])


Now, I want to average the overlapping elements, which in the example above are easily identifiable because they contain the same number (no processing has been applied).

The idea is to sum and then divide the elements in the last s columns. We need a tensor to keep track of how many additions are performed on each element. Moreover, unfold creates a view, so modifying entries in the last s columns also modifies the overlapping entries in the other columns; thus we need to adjust the values in the “edited” columns (or we could also copy the unfolded view to a new tensor if we don’t care about RAM).
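The copy-based variant mentioned in the parentheses can be sketched with two explicit buffers, one for the sums and one for the counts. This is not the in-place function from this post but a different, simpler technique; `overlap_mean` is a hypothetical name, and the sketch assumes the windows are in the last two dimensions.

```python
import torch


def overlap_mean(windows, step):
    """Average windows of shape (..., n_windows, size) into a tensor of shape (..., length)."""
    *lead, n_windows, size = windows.shape
    length = (n_windows - 1) * step + size          # positions covered by the windows
    total = windows.new_zeros((*lead, length))      # running sum per position
    count = windows.new_zeros(length)               # number of windows per position
    for offset in range(size):                      # loop over columns, not over windows
        # Column `offset` of window i lands on position offset + i * step.
        stop = offset + (n_windows - 1) * step + 1
        total[..., offset:stop:step] += windows[..., offset]
        count[offset:stop:step] += 1
    return total / count


w = torch.tensor([[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]])  # two windows, size 3, step 2
print(overlap_mean(w, step=2))  # tensor([1., 1., 2., 3., 3.])
```

The Python loop runs once per column of the window (k times), independent of the number of windows, and the output is a new tensor, so the unfolded view is left untouched.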

def fix_unfolded_view(a_unf, stride): 
     k = a_unf.shape[-1] 
     s = stride 
     divisor = torch.ones(*a_unf.shape[:-1], s, dtype=a_unf.dtype, device=a_unf.device) 
     for i in range(s): 
         for j in range(1, (k//s)+1): 
             idx = (k-1)-j*s-i 
             if idx < 0: 
                 break 
             a_unf[:, :, :, :-j, (k-1)-i] += a_unf[:, :, :, j:, idx] / divisor[:, :, :, :-j, -i-1] 
             divisor[:, :, :, :-j, -i-1] += 1     
     a_unf[:, :, :, :, k-s:] /= divisor    

Finally, the elements in columns [s:k-s] of the last dimension (in the example, the column starting with 2) should be fixed, since they overlap but were not touched. To fix them, we can recursively call the function itself, with a stride equal to k-2*s:

def fix_unfolded_view(a_unf, stride): 
     k = a_unf.shape[-1] 
     s = stride 
     divisor = torch.ones(*a_unf.shape[:-1], s, dtype=a_unf.dtype, device=a_unf.device) 
     for i in range(s): 
         for j in range(1, (k//s)+1): 
             idx = (k-1)-j*s-i 
             if idx < 0: 
                 break 
             a_unf[:, :, :, :-j, (k-1)-i] += a_unf[:, :, :, j:, idx] / divisor[:, :, :, :-j, -i-1] 
             divisor[:, :, :, :-j, -i-1] += 1     
     a_unf[:, :, :, :, k-s:] /= divisor    
     if k-2*s > 1:  
         del divisor
         fix_unfolded_view(a_unf[:, :, :, :, :-s], k-2*s) 
     else: 
         a_unf[:, :, :, :, k-s:] /= divisor  
         del divisor

We can then look back at the original tensor a, which is now modified, since we only used views.
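A tiny sketch of that aliasing, using a 1D tensor so it is easy to follow (the values here are made up for illustration): writing through the unfolded view changes the original tensor and every other window that contains the same position.

```python
import torch

a = torch.arange(6.0)
w = a.unfold(0, 4, 2)   # two windows: positions 0..3 and positions 2..5

w[0, 2] = 100.0         # write to position 2 through the first window

print(a[2].item())      # 100.0 -> the original tensor changed
print(w[1, 0].item())   # 100.0 -> the second window sees the same memory
print(w.data_ptr() == a.data_ptr())  # True: no copy was made
```

Calling `.clone()` on the result of `unfold` gives independent windows at the cost of extra memory.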

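In the worst case the two for loops together run O(k) iterations (s outer iterations times at most k//s inner ones), and fewer if s > 1 because the inner loop breaks early; we can reasonably expect around 3 to 5 iterations in practice.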
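To check the iteration count, the loop structure can be replayed without any tensors. This is just a counting sketch; `count_loop_iterations` is a made-up name, and it mirrors the two loops and the early break from fix_unfolded_view, without the recursive call.

```python
def count_loop_iterations(k, s):
    """Count how many times the loop body of fix_unfolded_view runs for one call."""
    n = 0
    for i in range(s):
        for j in range(1, (k // s) + 1):
            idx = (k - 1) - j * s - i
            if idx < 0:
                break
            n += 1
    return n


print(count_loop_iterations(5, 2))  # 3
print(count_loop_iterations(5, 1))  # 4
```

For 1 <= s <= k the count works out to k - s, so a larger stride means fewer iterations.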

The recursive call isn’t that good, but it’s only called once or twice, usually.