If I have a 3 × 768 tensor `[a, b, c]` (each of `a`, `b`, `c` being a 768-element row), how can I repeat it to get `[[a, b, c], [a, b, c], [a, b, c]]` instead of `[[a, a, a], [b, b, b], [c, c, c]]`, which is what I get with `torch.repeat_interleave`? Thanks!
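You can tile the whole tensor along a new leading dimension with `Tensor.repeat`, or add that dimension with `unsqueeze(0)` and then call `torch.repeat_interleave` on it: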
# Tile a (3, m) tensor [a, b, c] n times along a new leading dimension, giving
# [[a, b, c], [a, b, c], ...] with shape (n, 3, m). Calling repeat_interleave
# directly on dim 0 of the (3, m) tensor instead repeats each row in place
# ([a, a, a, b, b, b, ...]).
import torch

n = 3  # number of copies of the whole [a, b, c] block
m = 4  # 768 in the question; 4 here so the printed result is readable

a = torch.tensor([1] * m)
b = torch.tensor([2] * m)
c = torch.tensor([3] * m)
abc = torch.stack([a, b, c], dim=0)  # shape (3, m)
print(abc)
# tensor([[1, 1, 1, 1],
#         [2, 2, 2, 2],
#         [3, 3, 3, 3]])

# Option 1: Tensor.repeat. Giving more repeat counts than abc has dimensions
# adds new leading dimensions, so (n, 1, 1) maps shape (3, m) -> (n, 3, m).
tiled = abc.repeat(n, 1, 1)
print(tiled)
# tensor([[[1, 1, 1, 1],
#          [2, 2, 2, 2],
#          [3, 3, 3, 3]],
#
#         [[1, 1, 1, 1],
#          [2, 2, 2, 2],
#          [3, 3, 3, 3]],
#
#         [[1, 1, 1, 1],
#          [2, 2, 2, 2],
#          [3, 3, 3, 3]]])

# Option 2: repeat_interleave on an unsqueezed tensor. After unsqueeze(0) the
# shape is (1, 3, m). Dim 0 then holds a single slice, so repeating each slice
# n times is the same as tiling the whole block. Prints the same as above.
tiled_interleave = torch.repeat_interleave(abc.unsqueeze(0), repeats=n, dim=0)
print(tiled_interleave)

# Option 3: Tensor.expand returns a view of shape (n, 3, m) without copying
# data (-1 keeps a dimension's size). All n copies share abc's memory, so
# call .clone() first if you intend to write into the result.
tiled_view = abc.expand(n, -1, -1)
print(tiled_view)

# If you instead want a flat (3 * n, m) tensor [a, b, c, a, b, c, ...],
# use abc.repeat(n, 1).