How can I concatenate sub-tensors along a given dimension?

Suppose I have a tensor a with shape (2, 9, 4). Is there an efficient way to apply torch.cat to all the a[i] like this:

torch.cat([a[0], a[1]], dim=1)

so that the result is a tensor with shape (9, 8)?
Thanks in advance.
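For comparison, here is a minimal sketch of the direct torch.cat route, assuming the sizes from the question. `Tensor.unbind(0)` returns all the a[i] as a tuple, which extends the two-element list above to any size of the first dimension. It still builds a Python tuple of slices, so whether it is fast enough for a large first dimension is worth measuring.

```python
import torch

# Example input with the shape from the question: (2, 9, 4)
a = torch.randn(2, 9, 4)

# a.unbind(0) returns (a[0], a[1], ...), one slice per index of dim 0,
# so this generalizes torch.cat([a[0], a[1]], dim=1) to any size of dim 0.
out = torch.cat(a.unbind(0), dim=1)

print(out.shape)  # torch.Size([9, 8])
```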

Welcome to the PyTorch forums!

Check whether this works for you:

a = torch.permute(a, (1, 0, 2)).view(a.shape[1], -1)

It's just permuting the dimensions and changing how we view the data.
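A step-by-step sketch of the permute-then-merge idea on a small example (the torch.arange values are chosen only so the result is easy to inspect). It uses .reshape() rather than .view(), because .view() raises an error on the permuted tensor, and it checks the result against torch.cat.

```python
import torch

a = torch.arange(2 * 9 * 4).reshape(2, 9, 4)

# Step 1: swap dims 0 and 1: (2, 9, 4) -> (9, 2, 4), so p[j, i, k] == a[i, j, k]
p = a.permute(1, 0, 2)

# Step 2: merge the last two dims: (9, 2, 4) -> (9, 8).
# Row j becomes [a[0, j, :], a[1, j, :]], i.e. the sub-tensors side by side.
out = p.reshape(a.shape[1], -1)

expected = torch.cat([a[0], a[1]], dim=1)
print(p.shape, out.shape)          # torch.Size([9, 2, 4]) torch.Size([9, 8])
print(torch.equal(out, expected))  # True
```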


Thanks for replying.
With a slight change, it works:

a = a.permute(1, 0, 2).reshape(a.shape[1], -1)

.view() doesn't work here: it raised an error asking me to "Call .contiguous() before .view()", because the permuted tensor is not contiguous. I guess a memory copy is unavoidable.
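A small sketch of why .view() fails in this case: permute() only rearranges strides, so the permuted tensor is not contiguous, and merging its last two dimensions cannot be expressed as a view of the original memory. Both .contiguous().view() and .reshape() work; per the PyTorch docs, reshape() returns a view when it can and copies otherwise, so here both end up copying.

```python
import torch

a = torch.randn(2, 9, 4)
p = a.permute(1, 0, 2)  # shape (9, 2, 4), shares storage with a

# permute only changes strides, so p is no longer contiguous in memory
print(p.is_contiguous())  # False

try:
    p.view(9, -1)  # dims 1 and 2 cannot be merged without copying
except RuntimeError as e:
    print("view failed:", e)

# Both produce the same (9, 8) result; each makes a copy here.
v = p.contiguous().view(9, -1)
r = p.reshape(9, -1)
print(torch.equal(v, r))  # True

# The copy lives in new memory, separate from a's storage
print(r.data_ptr() == a.data_ptr())  # False
```

If only the (9, 8) result is needed, .reshape() is the simpler spelling; .contiguous().view() just makes the copy explicit.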