I think that I’ve found a workaround.
The idea is that after the unfolding you have a tensor with the windows in the last dimension:
In [10]: a = torch.arange(4*1*2*15).float().reshape(4, 1, 2, 15)
In [11]: a
Out[11]:
tensor([[[[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.,
11., 12., 13., 14.],
[ 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
26., 27., 28., 29.]]],
[[[ 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40.,
41., 42., 43., 44.],
[ 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55.,
56., 57., 58., 59.]]],
[[[ 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70.,
71., 72., 73., 74.],
[ 75., 76., 77., 78., 79., 80., 81., 82., 83., 84., 85.,
86., 87., 88., 89.]]],
[[[ 90., 91., 92., 93., 94., 95., 96., 97., 98., 99., 100.,
101., 102., 103., 104.],
[105., 106., 107., 108., 109., 110., 111., 112., 113., 114., 115.,
116., 117., 118., 119.]]]])
In [12]: k = 5; s = 2
In [13]: a_unf = a.unfold(3, k, s)
In [14]: a_unf
Out[14]:
tensor([[[[[ 0., 1., 2., 3., 4.],
[ 2., 3., 4., 5., 6.],
[ 4., 5., 6., 7., 8.],
[ 6., 7., 8., 9., 10.],
[ 8., 9., 10., 11., 12.],
[ 10., 11., 12., 13., 14.]],
[[ 15., 16., 17., 18., 19.],
[ 17., 18., 19., 20., 21.],
[ 19., 20., 21., 22., 23.],
[ 21., 22., 23., 24., 25.],
[ 23., 24., 25., 26., 27.],
[ 25., 26., 27., 28., 29.]]]],
[[[[ 30., 31., 32., 33., 34.],
[ 32., 33., 34., 35., 36.],
[ 34., 35., 36., 37., 38.],
[ 36., 37., 38., 39., 40.],
[ 38., 39., 40., 41., 42.],
[ 40., 41., 42., 43., 44.]],
[[ 45., 46., 47., 48., 49.],
[ 47., 48., 49., 50., 51.],
[ 49., 50., 51., 52., 53.],
[ 51., 52., 53., 54., 55.],
[ 53., 54., 55., 56., 57.],
[ 55., 56., 57., 58., 59.]]]],
[[[[ 60., 61., 62., 63., 64.],
[ 62., 63., 64., 65., 66.],
[ 64., 65., 66., 67., 68.],
[ 66., 67., 68., 69., 70.],
[ 68., 69., 70., 71., 72.],
[ 70., 71., 72., 73., 74.]],
[[ 75., 76., 77., 78., 79.],
[ 77., 78., 79., 80., 81.],
[ 79., 80., 81., 82., 83.],
[ 81., 82., 83., 84., 85.],
[ 83., 84., 85., 86., 87.],
[ 85., 86., 87., 88., 89.]]]],
[[[[ 90., 91., 92., 93., 94.],
[ 92., 93., 94., 95., 96.],
[ 94., 95., 96., 97., 98.],
[ 96., 97., 98., 99., 100.],
[ 98., 99., 100., 101., 102.],
[100., 101., 102., 103., 104.]],
[[105., 106., 107., 108., 109.],
[107., 108., 109., 110., 111.],
[109., 110., 111., 112., 113.],
[111., 112., 113., 114., 115.],
[113., 114., 115., 116., 117.],
[115., 116., 117., 118., 119.]]]]])
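As a quick sanity check of the layout above, the sketch below confirms the shape of the unfolded tensor and that window `w`, column `c` corresponds to position `w*s + c` of the original last dimension. The names `n_windows`, `w` and `c` are mine and do not appear in the session above.

```python
import torch

a = torch.arange(4 * 1 * 2 * 15).float().reshape(4, 1, 2, 15)
k, s = 5, 2
a_unf = a.unfold(3, k, s)

# Number of windows along the unfolded dimension: (L - k) // s + 1
n_windows = (a.shape[3] - k) // s + 1
print(a_unf.shape)  # torch.Size([4, 1, 2, 6, 5])

# Element (w, c) of the unfolded view is element w*s + c of the original row
for w in range(n_windows):
    for c in range(k):
        assert torch.equal(a_unf[..., w, c], a[..., w * s + c])
```

This mapping is what makes the overlaps easy to spot: two entries of the unfolded tensor overlap exactly when their `w*s + c` values coincide.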
Now, I want to average the overlapping elements, which in the example above are easy to identify because they contain the same number (no processing has been applied yet).
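To make the goal concrete, here is a minimal out-of-place sketch of what averaging the overlapping elements means, assuming the windows live in a tensor of shape `(..., n_windows, k)`. The function name `average_overlaps` and the `total`/`count` buffers are illustrative choices of mine. This is a plain overlap-add reference that allocates new memory, not the in-place approach described in this post.

```python
import torch

def average_overlaps(windows, stride, length):
    """Overlap-add the windows into a signal of size `length`, then average."""
    n_windows, k = windows.shape[-2], windows.shape[-1]
    total = torch.zeros(*windows.shape[:-2], length, dtype=windows.dtype)
    count = torch.zeros(length, dtype=windows.dtype)
    for w in range(n_windows):
        start = w * stride
        total[..., start:start + k] += windows[..., w, :]  # sum of contributions
        count[start:start + k] += 1                        # how many windows hit each position
    # Positions covered by no window keep count 0; clamp to avoid dividing by zero
    return total / count.clamp(min=1)

a = torch.arange(4 * 1 * 2 * 15).float().reshape(4, 1, 2, 15)
a_unf = a.unfold(3, 5, 2)
restored = average_overlaps(a_unf, stride=2, length=15)
assert torch.equal(restored, a)  # unprocessed windows average back to the input
```

With unprocessed windows the average simply reproduces the input, which makes this a convenient reference to compare any in-place variant against.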
The idea is to sum and then divide the elements in the last s columns. We need a tensor to keep track of how many additions are performed on each element. Moreover, unfold creates a view, so modifying entries in the last s columns also modifies the same elements wherever they appear in the other columns; thus we need to adjust the values in the “edited” columns (or we could copy the unfolded view to a new tensor if we don't care about RAM).
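The aliasing mentioned above can be checked directly. The following is a small sketch on a 1-D tensor (smaller than the example in this post, chosen by me for brevity) showing that a write through the unfolded view reaches the original tensor and every other window containing the same element, while a `clone()` does not.

```python
import torch

a = torch.arange(15.0)
a_unf = a.unfold(0, 5, 2)  # 6 windows of size 5, stride 2

# The view shares memory with `a`: no data is copied
assert a_unf.data_ptr() == a.data_ptr()

# Writing to window 0, column 4 touches position 4 of `a` ...
a_unf[0, 4] = -1.0
assert a[4] == -1.0
# ... and therefore also window 1, column 2 and window 2, column 0
assert a_unf[1, 2] == -1.0 and a_unf[2, 0] == -1.0

# A clone owns its memory, so edits no longer propagate (at the cost of extra RAM)
b_unf = a.unfold(0, 5, 2).clone()
b_unf[0, 0] = 100.0
assert a[0] == 0.0
```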
def fix_unfolded_view(a_unf, stride):
    k = a_unf.shape[-1]
    s = stride
    divisor = torch.ones(a_unf.shape[0], a_unf.shape[1], a_unf.shape[2], a_unf.shape[3], s)
    for i in range(s):
        for j in range(1, (k//s)+1):
            idx = (k-1)-j*s-i
            if idx < 0:
                break
            a_unf[:, :, :, :-j, (k-1)-i] += a_unf[:, :, :, j:, idx] / divisor[:, :, :, :-j, -i-1]
            divisor[:, :, :, :-j, -i-1] += 1
    a_unf[:, :, :, :, k-s:] /= divisor
Finally, the elements in columns [s:-s] of the last dimension (in the example, the column starting with 2) should be fixed, since they overlap but were not touched. To fix them, we can call the function recursively with a stride equal to k-2*s:
def fix_unfolded_view(a_unf, stride):
    k = a_unf.shape[-1]
    s = stride
    divisor = torch.ones(a_unf.shape[0], a_unf.shape[1], a_unf.shape[2], a_unf.shape[3], s)
    for i in range(s):
        for j in range(1, (k//s)+1):
            idx = (k-1)-j*s-i
            if idx < 0:
                break
            a_unf[:, :, :, :-j, (k-1)-i] += a_unf[:, :, :, j:, idx] / divisor[:, :, :, :-j, -i-1]
            divisor[:, :, :, :-j, -i-1] += 1
    a_unf[:, :, :, :, k-s:] /= divisor
    if k-2*s > 1:
        del divisor
        fix_unfolded_view(a_unf[:, :, :, :, :-s], k-2*s)
    else:
        a_unf[:, :, :, :, k-s:] /= divisor
        del divisor
We can then look back at the original tensor a, which is now modified, since we only used views.
The for loops run O(k) iterations in the worst case, and fewer if s > 1; in practice we can reasonably expect about 3 to 5 iterations.
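To put a number on the loop cost, the sketch below replays only the loop structure of `fix_unfolded_view` (no tensor work) and counts how often the inner body runs. The helper `count_iterations` is hypothetical and exists only for this illustration. Under the assumption `1 <= s <= k`, the count comes out as `k - s`.

```python
def count_iterations(k, s):
    """Count executions of the inner loop body for window size k and stride s."""
    n = 0
    for i in range(s):
        for j in range(1, (k // s) + 1):
            if (k - 1) - j * s - i < 0:
                break
            n += 1
    return n

print(count_iterations(5, 2))  # 3
print(count_iterations(5, 1))  # 4

# The count equals k - s for every stride between 1 and k
for k in range(1, 20):
    for s in range(1, k + 1):
        assert count_iterations(k, s) == k - s
```

Each counted iteration is still a vectorized operation over the whole tensor, so this only bounds the number of Python-level steps, not the total work.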
The recursive call isn’t that good, but it’s only called once or twice, usually.