Suppose I have a tensor a with shape (2, 9, 4). Is there an efficient way to apply torch.cat to all the a[i]s, like this:
torch.cat([a[0], a[1]], dim=1)
so that it becomes a tensor with shape (9, 8)?
Thanks in advance.
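For reference, a minimal sketch of the direct torch.cat route over every slice, assuming a random example tensor with the shape from the question. Because torch.cat accepts any sequence of tensors, a.unbind(0) can supply all the a[i] slices at once, so nothing is hard-coded to two slices:

```python
import torch

# Example tensor with the shape from the question: (2, 9, 4)
a = torch.randn(2, 9, 4)

# unbind(0) splits `a` along dim 0 into a tuple of (9, 4) views
# a[0], a[1], ...; torch.cat then joins them side by side along dim 1.
out = torch.cat(a.unbind(0), dim=1)

print(out.shape)  # torch.Size([9, 8])
```

For a tensor of shape (n, 9, 4) the same line gives shape (9, 4 * n).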
Welcome to the PyTorch forums!
Check if this will work for you.
a = torch.permute(a, (1, 0, 2)).view(a.shape[1], -1)
It just permutes the dimensions, (2, 9, 4) -> (9, 2, 4), and then changes how the data is viewed, merging the last two dimensions into a single dimension of size 8.
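A small sketch (assuming a random example tensor) that checks the permute-and-flatten idea against the explicit torch.cat call. It uses .reshape() rather than .view() because the permuted tensor is not contiguous in memory, and .reshape() accepts such input:

```python
import torch

a = torch.randn(2, 9, 4)

# Move dim 1 (size 9) to the front: (2, 9, 4) -> (9, 2, 4).
# Row j of p now holds a[0][j] followed by a[1][j].
p = a.permute(1, 0, 2)

# Flatten the last two dims: (9, 2, 4) -> (9, 8).
merged = p.reshape(a.shape[1], -1)

reference = torch.cat([a[0], a[1]], dim=1)
print(merged.shape)                    # torch.Size([9, 8])
print(torch.equal(merged, reference))  # True
```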
Thanks for replying.
Slightly changed and it works:
a = a.permute(1, 0, 2).reshape(a.shape[1], -1)
.view() doesn’t work: it raised an error asking me to “Call .contiguous() before .view()”, so I guess a memory copy is unavoidable.
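A sketch of why .view() fails here, assuming a random contiguous (2, 9, 4) input. permute only rearranges strides, so the result shares storage with a but is non-contiguous. A (9, 8) tensor whose rows are [a[0][j], a[1][j]] cannot be expressed as a strided view of that storage, so any method producing it, torch.cat included, has to allocate new memory:

```python
import torch

a = torch.randn(2, 9, 4)
p = a.permute(1, 0, 2)

# permute changes only the strides: same storage, non-contiguous layout.
print(p.is_contiguous())             # False
print(p.data_ptr() == a.data_ptr())  # True

# .view() needs a stride-compatible layout, so it raises on p.
try:
    p.view(9, -1)
except RuntimeError as e:
    print("view failed:", type(e).__name__)  # view failed: RuntimeError

# .reshape() falls back to copying when a view is impossible;
# here it matches p.contiguous().view(9, -1).
r = p.reshape(9, -1)
c = p.contiguous().view(9, -1)
print(r.data_ptr() == a.data_ptr())  # False: new memory was allocated
print(torch.equal(r, c))             # True
```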