How to tile a tensor?

For a general solution that works along any dimension, I implemented tile on top of the .repeat method of torch’s tensors, followed by an index_select that places the copies of each slice next to each other:

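import numpy as np
import torch

def tile(a, dim, n_tile):
    init_dim = a.size(dim)
    # First repeat the whole tensor n_tile times along `dim`
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*repeat_idx)
    # Then reorder so that all copies of slice i are adjacent:
    # positions i, i + init_dim, ..., i + (n_tile - 1) * init_dim
    idx = np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
    order_index = torch.as_tensor(idx, dtype=torch.long, device=a.device)
    return torch.index_select(a, dim, order_index)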
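If you would rather avoid the numpy round trip, the same reordering index can be built with torch ops directly. This is a sketch under that assumption; tile_torch is a hypothetical name, not a PyTorch function. Broadcasting an (init_dim, 1) column against an (n_tile,) row gives a matrix whose row i is [i, i + init_dim, ...], and flattening it row by row yields the same index as the np.concatenate call:

```python
import torch

def tile_torch(a: torch.Tensor, dim: int, n_tile: int) -> torch.Tensor:
    # Hypothetical numpy-free variant of tile
    init_dim = a.size(dim)
    # Repeat the whole tensor n_tile times along dim
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*repeat_idx)
    # Row i of this (init_dim, n_tile) matrix is i, i + init_dim, ..., i + (n_tile - 1) * init_dim
    offsets = torch.arange(n_tile, device=a.device) * init_dim
    base = torch.arange(init_dim, device=a.device).unsqueeze(1)
    order_index = (base + offsets).reshape(-1)  # row-major flatten
    return torch.index_select(a, dim, order_index)
```

Building the index on a.device keeps index_select working when the input lives on a GPU.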

Examples:

t = torch.FloatTensor([[1,2,3],[4,5,6]])
tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.]])
  • Across dim 0:
tile(t,0,3)
tensor([[ 1.,  2.,  3.],
        [ 1.,  2.,  3.],
        [ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 4.,  5.,  6.],
        [ 4.,  5.,  6.]])
  • Across dim 1:
tile(t,1,2)
tensor([[ 1.,  1.,  2.,  2.,  3.,  3.],
        [ 4.,  4.,  5.,  5.,  6.,  6.]])
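To make the difference from plain .repeat explicit, here is a small check. It assumes a PyTorch version that provides torch.repeat_interleave (added around 1.1), which, as far as I can tell, produces the same interleaved layout as tile:

```python
import numpy as np
import torch

def tile(a, dim, n_tile):
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*repeat_idx)
    idx = np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
    order_index = torch.as_tensor(idx, dtype=torch.long, device=a.device)
    return torch.index_select(a, dim, order_index)

t = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])

# Plain .repeat copies the whole tensor as a block: rows t[0], t[1], t[0], t[1], ...
blocked = t.repeat(3, 1)
assert torch.equal(blocked[:2], t) and torch.equal(blocked[2:4], t)

# tile keeps the copies of each row together: t[0], t[0], t[0], t[1], t[1], t[1]
assert torch.equal(tile(t, 0, 3), torch.repeat_interleave(t, 3, dim=0))
assert torch.equal(tile(t, 1, 2), torch.repeat_interleave(t, 2, dim=1))
```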

No benchmarking performed, though :)
