Oh, in that case, neither of these solutions works:
>>> t = torch.tensor([[1, 2, 3], [4, 4, 4]])
>>> t
tensor([[1, 2, 3],
        [4, 4, 4]])
>>> torch.cat(3*[t])
tensor([[1, 2, 3],
        [4, 4, 4],
        [1, 2, 3],
        [4, 4, 4],
        [1, 2, 3],
        [4, 4, 4]])
>>> t.repeat(3, 1)
tensor([[1, 2, 3],
        [4, 4, 4],
        [1, 2, 3],
        [4, 4, 4],
        [1, 2, 3],
        [4, 4, 4]])
Both interleave whole copies of the tensor instead of repeating each row in place. But based on the answer to "How to tile a tensor?", we can write a function that places the repeated elements in succession:
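import numpy as np
import torch

def tile(a, dim, n_tile):
    # Repeat the whole tensor n_tile times along `dim`
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*repeat_idx)
    # Reorder so copies of each slice are adjacent: 0, d, 2d, ..., 1, d+1, ...
    order_index = torch.LongTensor(
        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    # The index must live on the same device as `a` (e.g. when `a` is on a GPU)
    return torch.index_select(a, dim, order_index.to(a.device))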
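To see why the reordering step works, here is the index construction from tile() on its own, as a plain NumPy sketch using the sizes from the example (a dimension of size 2 tiled 3 times). The string labels r0/r1 are only illustrative stand-ins for the two rows of t.

```python
import numpy as np

# Same index construction as in tile(), for a dimension of size 2 tiled 3 times
init_dim, n_tile = 2, 3
order_index = np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
print(order_index)  # [0 2 4 1 3 5]

# After repeat(), the rows are ordered r0, r1, r0, r1, r0, r1.
# Picking rows 0, 2, 4 gathers the three copies of r0; rows 1, 3, 5 gather r1.
rows = np.array(["r0", "r1"] * n_tile)
print(rows[order_index])  # ['r0' 'r0' 'r0' 'r1' 'r1' 'r1']
```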
Then, using this function, each row is repeated in succession:
>>> tile(t, 0, 3)
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3],
        [4, 4, 4],
        [4, 4, 4],
        [4, 4, 4]])
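Recent PyTorch versions (1.1 and later) also provide torch.repeat_interleave, which produces this ordering directly without building an index by hand. A minimal sketch, assuming your installed version has it:

```python
import torch

t = torch.tensor([[1, 2, 3], [4, 4, 4]])

# Each row is repeated 3 times in succession along dim 0
print(torch.repeat_interleave(t, 3, dim=0))
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [1, 2, 3],
#         [4, 4, 4],
#         [4, 4, 4],
#         [4, 4, 4]])

# The same works along other dimensions, e.g. each column twice along dim 1
print(t.repeat_interleave(2, dim=1))
# tensor([[1, 1, 2, 2, 3, 3],
#         [4, 4, 4, 4, 4, 4]])
```

On older versions without repeat_interleave, the tile() function remains a workable fallback.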