I figured it out:
import torch

h, w = 4, 10
repeats = 2
my_tensor = torch.randint(0, 10, [h, w])
print("my_tensor:", my_tensor)
repeated_1 = my_tensor.repeat_interleave(repeats, -2)
print("repeated_1:", repeated_1)
repeated_2 = my_tensor.view(h, 1, w).expand(h, repeats, w).contiguous().view(-1, w)
print("repeated_2:", repeated_2)
my_tensor:
tensor([[7, 6, 8, 5, 5, 2, 2, 1, 2, 8],
[4, 0, 3, 2, 8, 4, 5, 9, 3, 1],
[3, 3, 3, 1, 3, 3, 2, 4, 6, 5],
[0, 4, 8, 7, 7, 7, 3, 4, 6, 9]])
repeated_1:
tensor([[7, 6, 8, 5, 5, 2, 2, 1, 2, 8],
[7, 6, 8, 5, 5, 2, 2, 1, 2, 8],
[4, 0, 3, 2, 8, 4, 5, 9, 3, 1],
[4, 0, 3, 2, 8, 4, 5, 9, 3, 1],
[3, 3, 3, 1, 3, 3, 2, 4, 6, 5],
[3, 3, 3, 1, 3, 3, 2, 4, 6, 5],
[0, 4, 8, 7, 7, 7, 3, 4, 6, 9],
[0, 4, 8, 7, 7, 7, 3, 4, 6, 9]])
repeated_2:
tensor([[7, 6, 8, 5, 5, 2, 2, 1, 2, 8],
[7, 6, 8, 5, 5, 2, 2, 1, 2, 8],
[4, 0, 3, 2, 8, 4, 5, 9, 3, 1],
[4, 0, 3, 2, 8, 4, 5, 9, 3, 1],
[3, 3, 3, 1, 3, 3, 2, 4, 6, 5],
[3, 3, 3, 1, 3, 3, 2, 4, 6, 5],
[0, 4, 8, 7, 7, 7, 3, 4, 6, 9],
[0, 4, 8, 7, 7, 7, 3, 4, 6, 9]])
I had to use contiguous(). Having never used it before, I expected the expand-based version to take about the same time as just using repeat_interleave().
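For anyone wondering why contiguous() is needed: expand() returns a non-contiguous view (it repeats data via a zero stride instead of copying), and view() refuses to operate on such a tensor. A minimal sketch with the same toy shapes (names are just for illustration):

import torch

t = torch.randint(0, 10, [4, 10])
expanded = t.view(4, 1, 10).expand(4, 2, 10)  # non-contiguous view, no copy yet
print(expanded.is_contiguous())  # False
try:
    expanded.view(-1, 10)  # view() needs contiguous memory here
except RuntimeError as e:
    print(e)
# contiguous() materializes the copy so view() works;
# reshape() would do the same copy-if-needed in a single call
flat = expanded.contiguous().view(-1, 10)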
And, weirdly, timing these two operations in isolation gives me similar results, but in my training script the latter is several times faster. Maybe that is due to torch.backends.cudnn.benchmark = True?
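In case anyone wants to reproduce the timing, here is a rough sketch using torch.utils.benchmark, which handles warmup and CUDA synchronization for you (the shapes are the toy ones from above, not my training ones):

import torch
import torch.utils.benchmark as benchmark

h, w, repeats = 4, 10, 2
t = torch.randint(0, 10, [h, w])

timer_1 = benchmark.Timer(
    stmt="t.repeat_interleave(repeats, -2)",
    globals={"t": t, "repeats": repeats},
)
timer_2 = benchmark.Timer(
    stmt="t.view(h, 1, w).expand(h, repeats, w).contiguous().view(-1, w)",
    globals={"t": t, "h": h, "w": w, "repeats": repeats},
)
print(timer_1.timeit(1000))
print(timer_2.timeit(1000))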
Anyway, I am happy with the results.