For a general solution that works on a tensor of any dimension, I implemented tile based on the .repeat method of torch's tensors:
import numpy as np
import torch

def tile(a, dim, n_tile):
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    # Repeat the tensor n_tile times along the target dimension
    a = a.repeat(*repeat_idx)
    # Reorder so all copies of each original element end up adjacent
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(a, dim, order_index)
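To make the reordering concrete, here is the index this builds for a 2-row tensor tiled 3 times along dim 0 (a small sketch using the same values as the examples below):

import numpy as np

# init_dim = 2 (rows), n_tile = 3
idx = np.concatenate([2 * np.arange(3) + i for i in range(2)])
print(idx)  # [0 2 4 1 3 5] -> selects all copies of row 0 first, then row 1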
Examples:

t = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
t
Out:
tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.]])
- Across dim 0:

tile(t, 0, 3)
Out:
tensor([[ 1.,  2.,  3.],
        [ 1.,  2.,  3.],
        [ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 4.,  5.,  6.],
        [ 4.,  5.,  6.]])
- Across dim 1:

tile(t, 1, 2)
Out:
tensor([[ 1.,  1.,  2.,  2.,  3.,  3.],
        [ 4.,  4.,  5.,  5.,  6.,  6.]])
I haven't benchmarked it, though.
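As an aside, if you're on a recent PyTorch release (1.1 or later, as far as I know), torch.repeat_interleave does this element-wise tiling natively, without the manual index bookkeeping:

import torch

t = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
print(torch.repeat_interleave(t, 3, dim=0))  # same result as tile(t, 0, 3)
print(torch.repeat_interleave(t, 2, dim=1))  # same result as tile(t, 1, 2)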