Concatenate given dimensions in a tensor

Hello,

I am trying to find an efficient solution for the following problem and it seems I cannot get it right.

I have a tensor like:

tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [10, 11, 12]],

        [[13, 14, 15],
         [16, 17, 18],
         [19, 20, 21],
         [22, 23, 24]]])

I want to turn this into:

tensor([[ 1,  2,  3, 13, 14, 15],
        [ 4,  5,  6, 16, 17, 18],
        [ 7,  8,  9, 19, 20, 21],
        [10, 11, 12, 22, 23, 24]])

I need to do this in a computationally efficient way, because this operation runs very frequently in my application.

Could you help me please?
Thanks!

PS: This is my solution so far, but it looks suboptimal:
torch.cat(tuple(t[:]), 1)
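One way to check whether an alternative actually helps is to time it against this expression. A minimal sketch using the standard-library `timeit` module; the larger tensor size and the `permute`/`reshape` alternative are assumptions chosen for illustration, and the numbers will depend on your hardware and tensor shapes.

```python
import timeit

import torch

# Assumption: a larger tensor than the 2x4x3 example, so timings are meaningful.
t = torch.randn(2, 4096, 256)

def cat_version(x):
    # The approach from the question: unpack along dim 0, concatenate along dim 1.
    return torch.cat(tuple(x[:]), 1)

def permute_version(x):
    # Alternative (assumed for comparison): swap dims 0 and 1, then merge the
    # last two dims; reshape() has to copy because the permuted view is not contiguous.
    return x.permute(1, 0, 2).reshape(x.shape[1], x.shape[0] * x.shape[2])

# Both must give the same result before the timings mean anything.
assert torch.equal(cat_version(t), permute_version(t))

for fn in (cat_version, permute_version):
    seconds = timeit.timeit(lambda: fn(t), number=100)
    print(f"{fn.__name__}: {seconds / 100 * 1e6:.1f} us per call")
```

Both versions make exactly one copy of the data, so any difference between them comes from overhead rather than from avoiding work.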

Hi Yann!

I think that, in general, your solution is the best you can do. The
issue is that if your original tensor is contiguous, you cannot
construct your desired tensor as a view into it: no choice of sizes
and strides visits the elements of the original tensor in the order
you want. So you have to copy the elements of the original tensor
to put them in the desired order.
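A small sketch of why the copy cannot be avoided for a contiguous input: `permute()` is free (it only changes strides), but merging the last two dimensions of the permuted tensor cannot be expressed with strides, so `view()` raises an error and `reshape()` falls back to copying. The variable names here are illustrative only.

```python
import torch

t = torch.arange(24).reshape(2, 4, 3) + 1  # contiguous, same values as in the question

# permute() is a view: no data moves, only the strides change.
p = t.permute(1, 0, 2)  # shape (4, 2, 3)
print(p.is_contiguous())  # False

# Merging the last two dims of p is not expressible with strides alone,
# so view() refuses ...
try:
    p.view(4, 6)
except RuntimeError as err:
    print("view failed:", err)

# ... while reshape() silently allocates new memory and copies.
result = p.reshape(4, 6)
print(result.data_ptr() == t.data_ptr())  # False: the data was copied
print(result)
```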

Having said that, it would be possible to get your desired tensor
without a copy in the unlikely event that your original tensor is
not contiguous and happens to have its elements in an order from
which you can create your desired tensor by striding.

For example:

>>> import torch
>>> torch.__version__
'1.9.0'
>>>
>>> t = torch.arange (24).reshape (2, 4, 3) + 1
>>> t
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [10, 11, 12]],

        [[13, 14, 15],
         [16, 17, 18],
         [19, 20, 21],
         [22, 23, 24]]])
>>> t.is_contiguous()
True
>>> t_result = torch.cat (t.split (1), dim = -1).squeeze()
>>> t_result
tensor([[ 1,  2,  3, 13, 14, 15],
        [ 4,  5,  6, 16, 17, 18],
        [ 7,  8,  9, 19, 20, 21],
        [10, 11, 12, 22, 23, 24]])
>>> t_special = torch.as_strided (t_result, (2, 4, 3), (3, 6, 1))
>>> t_special
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [10, 11, 12]],

        [[13, 14, 15],
         [16, 17, 18],
         [19, 20, 21],
         [22, 23, 24]]])
>>> t_special.is_contiguous()
False
>>> t_result_as_strided = torch.as_strided (t_special, (4, 6), (6, 1))
>>> t_result_as_strided
tensor([[ 1,  2,  3, 13, 14, 15],
        [ 4,  5,  6, 16, 17, 18],
        [ 7,  8,  9, 19, 20, 21],
        [10, 11, 12, 22, 23, 24]])
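The same zero-copy layout can also be reached without `as_strided`, using only `view()` and `permute()`, provided you control how the (2, 4, 3) tensor is allocated. A minimal sketch, assuming the code that produces the data can write into a pre-allocated buffer; the names `base` and `t_view` are hypothetical, and the `copy_()` call only stands in for whatever actually fills the tensor in your application.

```python
import torch

# Assumption: memory is allocated in the *final* (4, 6) order up front.
base = torch.empty(4, 6, dtype=torch.long)

# A (2, 4, 3) view of that memory, with the same strides as t_special: (3, 6, 1).
t_view = base.view(4, 2, 3).permute(1, 0, 2)
print(t_view.stride())  # (3, 6, 1)

# Stand-in for the real producer: writing through the view fills base
# in the desired order.
t_view.copy_(torch.arange(24).reshape(2, 4, 3) + 1)

# Undoing the permute gives a contiguous tensor, so reshape() returns a view.
t_result = t_view.permute(1, 0, 2).reshape(4, 6)
print(t_result.data_ptr() == base.data_ptr())  # True: no copy was made
print(t_result)
```

Whether this is practical depends on whether the upstream code can target such a view; if it cannot, the one-copy `torch.cat` approach remains the straightforward choice.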

Best.

K. Frank
