Repeat examples along batch dimension

Borrowing from my answer: for anyone new looking into this, PyTorch has since introduced torch.repeat_interleave(), which handles this in a single operation.

So for t = torch.tensor([[1, 2, 3], [4, 4, 4]]), one can use torch.repeat_interleave(t, repeats=3, dim=0) to repeat each example three times in place along the batch dimension and obtain:

tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3],
        [4, 4, 4],
        [4, 4, 4],
        [4, 4, 4]])
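A minimal runnable sketch, assuming a PyTorch version that ships torch.repeat_interleave. It puts the call above next to Tensor.repeat, which tiles the whole batch rather than keeping copies of each example together, and next to an expand/reshape formulation that gives the same result as repeat_interleave. The comparison, the expand/reshape variant, and the variable names are my own illustration, not part of the original answer.

```python
import torch

# Batch of 2 examples with 3 features each (same input as above).
t = torch.tensor([[1, 2, 3], [4, 4, 4]])

# repeat_interleave keeps the copies of each example adjacent: a, a, a, b, b, b
interleaved = torch.repeat_interleave(t, repeats=3, dim=0)

# Tensor.repeat tiles the whole batch instead: a, b, a, b, a, b
tiled = t.repeat(3, 1)

# Same result as repeat_interleave via a broadcast view plus reshape:
# (2, 3) -> (2, 1, 3) -> (2, 3, 3) -> (6, 3); reshape copies the expanded view.
expanded = t.unsqueeze(1).expand(-1, 3, -1).reshape(-1, t.size(1))

# repeats may also be a 1-D tensor giving a separate count for each example.
per_row = torch.repeat_interleave(t, torch.tensor([1, 2]), dim=0)

print(interleaved)
print(tiled)
print(torch.equal(interleaved, expanded))  # True
print(per_row)
```

The distinction matters for batching: with repeat_interleave, rows 0-2 all come from the first example, which is usually what you want when, say, sampling several outputs per input; Tensor.repeat would interleave the examples instead.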