Sliding window with repeated first element

Say I have a tensor like

test = torch.arange(16).reshape(4, 4)

This should look like

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

Is there a way to implement a fixed-length sliding window along the rows that outputs as many windows as there are rows, repeating the first row as many times as necessary to fill the early windows? In other words, given a window of length 3, I want to return a tensor that looks like this:

tensor([[[ 0,  1,  2,  3], 
         [ 0,  1,  2,  3], 
         [ 0,  1,  2,  3]],
        [[ 0,  1,  2,  3], 
         [ 0,  1,  2,  3], 
         [ 4,  5,  6,  7]],
        [[ 0,  1,  2,  3], 
         [ 4,  5,  6,  7], 
         [ 8,  9, 10, 11]],
        [[ 4,  5,  6,  7], 
         [ 8,  9, 10, 11], 
         [12, 13, 14, 15]]])

Currently, I just use regular Python for loops, and that works fine. However, I was wondering whether this can be done with built-in PyTorch functions.

You should be able to expand the first row, concatenate it with the remaining rows, and unfold the result. Whether this is faster than your loop depends on your actual use case, since it could also add memory overhead:

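import torch

test = torch.arange(16).view(4, 4)
# Repeat the first row 3 times (the window length), then append the remaining rows.
test = torch.cat((test[0].unsqueeze(0).expand(3, -1), test[1:]))

# unfold puts the window dimension last, so permute it in front of the columns.
out = test.unfold(0, 3, 1).permute(0, 2, 1)
print(out)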
# tensor([[[ 0,  1,  2,  3],
#          [ 0,  1,  2,  3],
#          [ 0,  1,  2,  3]],

#         [[ 0,  1,  2,  3],
#          [ 0,  1,  2,  3],
#          [ 4,  5,  6,  7]],

#         [[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7],
#          [ 8,  9, 10, 11]],

#         [[ 4,  5,  6,  7],
#          [ 8,  9, 10, 11],
#          [12, 13, 14, 15]]])
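
The snippet above hard-codes a window of 3. As a rough sketch, the same idea can be wrapped in a function that takes the window length as a parameter. The names sliding_window_repeat_first and sliding_window_loop are made up for this illustration (they are not PyTorch functions), and the loop version is only an assumed baseline for checking the result, not the code from the question.

```python
import torch


def sliding_window_repeat_first(x: torch.Tensor, window: int) -> torch.Tensor:
    """Return one window of `window` rows per row of `x`, left-padded with x[0].

    Hypothetical helper for illustration; assumes window >= 2 and x has at least one row.
    """
    # Repeat the first row (window - 1) extra times so every row ends a full window.
    pad = x[:1].expand(window - 1, *x.shape[1:])
    padded = torch.cat((pad, x), dim=0)
    # unfold appends the window dimension last; move it right after the row dimension.
    return padded.unfold(0, window, 1).movedim(-1, 1)


def sliding_window_loop(x: torch.Tensor, window: int) -> torch.Tensor:
    """Assumed loop baseline: clamp negative row indices to 0 (the first row)."""
    windows = []
    for i in range(x.shape[0]):
        idx = [max(j, 0) for j in range(i - window + 1, i + 1)]
        windows.append(x[idx])
    return torch.stack(windows)


test = torch.arange(16).reshape(4, 4)
out = sliding_window_repeat_first(test, 3)
print(out.shape)  # torch.Size([4, 3, 4])
print(torch.equal(out, sliding_window_loop(test, 3)))  # True
```

Only the torch.cat call copies data here (window - 1 extra rows). unfold and movedim return views of the padded tensor, so the extra memory should stay small unless a later operation such as contiguous() materializes the windows.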