Sliding window of 2D tensor with rollover

Hi guys, I’m trying to figure out how to create a sliding window on a 2D tensor that rolls over (wraps around) when it hits the boundaries. For example, suppose we have the following tensor:

tensor([[2, 4, 6, 5, 9, 7, 3, 8, 1, 0],
        [0, 2, 1, 9, 4, 5, 6, 8, 3, 7],
        [2, 5, 8, 9, 6, 3, 0, 1, 4, 7],
        [5, 2, 1, 6, 8, 4, 9, 0, 3, 7]])

where each row is a sequence of indexes and the tensor holds a batch of them. What I want is a sliding window of 5 elements centred on each element in turn: the focal element plus its 2 left and 2 right neighbors, wrapping around at the ends of the row. For example, if we take the first row [2, 4, 6, 5, 9, 7, 3, 8, 1, 0], the end result should be:

[1, 0, 2, 4, 6]
[0, 2, 4, 6, 5]
[2, 4, 6, 5, 9]
[4, 6, 5, 9, 7]
[6, 5, 9, 7, 3]
[5, 9, 7, 3, 8]
[9, 7, 3, 8, 1]
[7, 3, 8, 1, 0]
[3, 8, 1, 0, 2]
[8, 1, 0, 2, 4]

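As a plain-Python reference for the expected output above (a sketch assuming an odd window size; `windows_reference` is a hypothetical name), the window centred on position i takes the elements at indices i-2 through i+2, each taken modulo the row length:

```python
def windows_reference(row, size=5):
    # Hypothetical helper: assumes an odd `size` so every window has a centre element.
    half = size // 2
    n = len(row)
    # Python's % wraps negative indices to the end, e.g. -2 % 10 == 8.
    return [[row[(i + k) % n] for k in range(-half, half + 1)] for i in range(n)]

row = [2, 4, 6, 5, 9, 7, 3, 8, 1, 0]
for w in windows_reference(row):
    print(w)
```

This prints the ten windows listed above, starting with [1, 0, 2, 4, 6] and ending with [8, 1, 0, 2, 4].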
Tensor.unfold doesn’t quite give this result, since it never wraps around at the ends of a row.
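To illustrate the shortfall (a small sketch using the first two rows from above): without padding, unfold(dimension, size, step) only yields windows that fit entirely inside the row, so a row of 10 elements gives 10 - 5 + 1 = 6 windows, and the first one is centred on the third element rather than the first.

```python
import torch

x = torch.tensor([[2, 4, 6, 5, 9, 7, 3, 8, 1, 0],
                  [0, 2, 1, 9, 4, 5, 6, 8, 3, 7]])

# Windows of size 5, step 1, along dim 1; no wrap-around at the borders.
plain = x.unfold(1, 5, 1)
print(plain.shape)  # torch.Size([2, 6, 5])
print(plain[0, 0])  # tensor([2, 4, 6, 5, 9])
```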

Manual padding and unfold should work:

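import torch

x = torch.tensor([[2, 4, 6, 5, 9, 7, 3, 8, 1, 0],
                  [0, 2, 1, 9, 4, 5, 6, 8, 3, 7],
                  [2, 5, 8, 9, 6, 3, 0, 1, 4, 7],
                  [5, 2, 1, 6, 8, 4, 9, 0, 3, 7]])

# pad the borders circularly: prepend the last 2 columns and append the first 2
x_pad = torch.cat((x[:, -2:], x, x[:, :2]), dim=1)
# windows of size 5 with step 1 along dim 1 -> shape (4, 10, 5)
res = x_pad.unfold(1, 5, 1)
print(res)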
# tensor([[[1, 0, 2, 4, 6],
#          [0, 2, 4, 6, 5],
#          [2, 4, 6, 5, 9],
#          [4, 6, 5, 9, 7],
#          [6, 5, 9, 7, 3],
#          [5, 9, 7, 3, 8],
#          [9, 7, 3, 8, 1],
#          [7, 3, 8, 1, 0],
#          [3, 8, 1, 0, 2],
#          [8, 1, 0, 2, 4]],

#         [[3, 7, 0, 2, 1],
#          [7, 0, 2, 1, 9],
#          [0, 2, 1, 9, 4],
#          [2, 1, 9, 4, 5],
#          [1, 9, 4, 5, 6],
#          [9, 4, 5, 6, 8],
#          [4, 5, 6, 8, 3],
#          [5, 6, 8, 3, 7],
#          [6, 8, 3, 7, 0],
#          [8, 3, 7, 0, 2]],
# ...
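The same idea can be generalised to any odd window size, and the wrap-around can alternatively be expressed with modular index arithmetic plus advanced indexing. Below is a sketch of both; `circular_unfold` and `circular_gather` are hypothetical helper names, and the pad-based version assumes the half-window is no larger than the row length.

```python
import torch

def circular_unfold(x: torch.Tensor, size: int) -> torch.Tensor:
    """Pad-and-unfold generalised to an odd window size (hypothetical helper).

    Assumes size // 2 <= x.size(1).
    """
    if size % 2 != 1:
        raise ValueError("size must be odd so each window has a centre element")
    half = size // 2
    n = x.size(1)
    # Slice from n - half rather than -half so that half == 0 yields an empty slice,
    # not a copy of the whole row.
    x_pad = torch.cat((x[:, n - half:], x, x[:, :half]), dim=1)
    return x_pad.unfold(1, size, 1)  # (batch, n, size)

def circular_gather(x: torch.Tensor, size: int) -> torch.Tensor:
    """Same windows via modular indexing (hypothetical helper)."""
    half = size // 2
    n = x.size(1)
    offsets = torch.arange(-half, half + 1, device=x.device)
    # idx[i, j] = (i + j - half) mod n; tensor % follows Python's sign convention,
    # so -1 % n == n - 1.
    idx = (torch.arange(n, device=x.device).unsqueeze(1) + offsets) % n  # (n, size)
    return x[:, idx]  # (batch, n, size)

x = torch.tensor([[2, 4, 6, 5, 9, 7, 3, 8, 1, 0],
                  [0, 2, 1, 9, 4, 5, 6, 8, 3, 7],
                  [2, 5, 8, 9, 6, 3, 0, 1, 4, 7],
                  [5, 2, 1, 6, 8, 4, 9, 0, 3, 7]])

a = circular_unfold(x, 5)
b = circular_gather(x, 5)
print(torch.equal(a, b))  # True
print(a[0, 0].tolist())   # [1, 0, 2, 4, 6]
```

One trade-off to keep in mind: unfold returns a view of the padded tensor, which only holds n + 2 * half elements per row, whereas the gather version materialises all n * size elements. The gather version, on the other hand, also works when the window is wider than the row.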