Repeat examples along batch dimension

Oh, in that case, neither of these solutions works:

>>> t = torch.tensor([[1, 2, 3], [4, 4, 4]])
>>> t
tensor([[1, 2, 3],
        [4, 4, 4]])
>>> torch.cat(3*[t])
tensor([[1, 2, 3],
        [4, 4, 4],
        [1, 2, 3],
        [4, 4, 4],
        [1, 2, 3],
        [4, 4, 4]])
>>> t.repeat(3, 1)
tensor([[1, 2, 3],
        [4, 4, 4],
        [1, 2, 3],
        [4, 4, 4],
        [1, 2, 3],
        [4, 4, 4]])
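
Both calls stack whole copies of `t` one after another, so the rows come out cycling through the original rows instead of grouped. A quick check of that row order, as a small sketch (the names `block_order` and `wanted_order` are only illustrative):

```python
import torch

t = torch.tensor([[1, 2, 3], [4, 4, 4]])
n = 3

# Both calls stack whole copies of `t`, one after another
by_cat = torch.cat(n * [t])
by_repeat = t.repeat(n, 1)
assert torch.equal(by_cat, by_repeat)

# So output row k comes from original row k % t.size(0): 0, 1, 0, 1, 0, 1
block_order = torch.arange(n * t.size(0)) % t.size(0)
assert torch.equal(by_repeat, t[block_order])

# The order wanted here keeps the copies of each row together: 0, 0, 0, 1, 1, 1
wanted_order = torch.tensor([i for i in range(t.size(0)) for _ in range(n)])
grouped = t[wanted_order]
print(grouped)
```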

But based on the answer from "How to tile a tensor?", we can make the copies of each element appear in succession:

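import numpy as np
import torch

def tile(a, dim, n_tile):
    # Repeat each slice along `dim` n_tile times, keeping the copies adjacent
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*repeat_idx)  # block-wise copies: x0, x1, ..., x0, x1, ...
    # Copy k of slice i now sits at k * init_dim + i; gather them in that order
    order_index = torch.as_tensor(
        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]),
        dtype=torch.long, device=a.device)  # same device as `a`, so GPU works too
    return torch.index_select(a, dim, order_index)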
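
The `order_index` expression is the non-obvious part. Tracing it for the 2-row `t` above with `n_tile = 3` shows which rows of the block-wise repeat get picked, and in what order. This is just a standalone trace of the same expression, not a different method:

```python
import numpy as np
import torch

init_dim, n_tile = 2, 3  # t has 2 rows, each to be repeated 3 times

# In the block-wise repeat, the copies of original row i sit at
# positions i, i + init_dim, i + 2 * init_dim, ...
order_index = np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
print(order_index.tolist())  # [0, 2, 4, 1, 3, 5]

t = torch.tensor([[1, 2, 3], [4, 4, 4]])
repeated = t.repeat(n_tile, 1)  # rows: t0, t1, t0, t1, t0, t1
result = torch.index_select(repeated, 0, torch.as_tensor(order_index, dtype=torch.long))
print(result.tolist())  # [[1, 2, 3], [1, 2, 3], [1, 2, 3], [4, 4, 4], [4, 4, 4], [4, 4, 4]]
```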

Then, using this function, the copies of each row come out grouped together:

>>> tile(t, 0, 3)
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3],
        [4, 4, 4],
        [4, 4, 4],
        [4, 4, 4]])
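
In newer PyTorch versions (1.1 and later), the same grouping is available built in as `torch.repeat_interleave`, which repeats each slice along `dim` consecutively. The sketch below also shows an `unsqueeze`/`expand`/`reshape` variant for the batch dimension only; `tile_rows` is a hypothetical helper name, not a PyTorch function:

```python
import torch

t = torch.tensor([[1, 2, 3], [4, 4, 4]])

# Built-in: each row is repeated 3 times before moving on to the next row
grouped = torch.repeat_interleave(t, 3, dim=0)

def tile_rows(a, n_tile):
    # Hypothetical helper: insert a new axis after dim 0, broadcast it to
    # n_tile copies (no copy yet), then merge the first two axes (copies here)
    return a.unsqueeze(1).expand(-1, n_tile, *a.shape[1:]).reshape(-1, *a.shape[1:])

assert torch.equal(grouped, tile_rows(t, 3))
print(grouped)
```

Either form avoids building the index by hand; the `expand` version is only a fallback for older releases that lack `repeat_interleave`.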