How can I flatten two dimensions of a tensor?

Hi,

My question is this: suppose I have a tensor a = torch.randn(3, 4, 16, 16), and I want to flatten the first two dimensions so that its shape becomes (1, 12, 16, 16).

Right now the only way I have found is size = [1, -1] + list(a.size()[2:]); a = a.view(size), which I suspect is not the idiomatic PyTorch way. Is there a cleaner way to do it?

>>> a = torch.randn(3, 4, 16, 16)
>>> c = torch.cat(a.unbind()).unsqueeze(0)
>>> c.size()
torch.Size([1, 12, 16, 16])

Thanks, but is this efficient? Doesn't it involve a memory copy during unbind and cat?

I am not sure whether it is efficient.
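One way to check is to compare storage pointers. A minimal sketch, assuming a freshly created (contiguous) tensor: unbind returns views into a, so splitting is free, but torch.cat writes every slice into a newly allocated tensor, so the result does not share memory with a.

```python
import torch

a = torch.randn(3, 4, 16, 16)

# unbind() splits along dim 0 into three views of shape (4, 16, 16); nothing is copied.
parts = a.unbind()
print(parts[0].data_ptr() == a.data_ptr())  # True: the first slice starts at a's storage

# torch.cat allocates a new (12, 16, 16) tensor and copies each slice into it;
# unsqueeze(0) then adds the leading size-1 dim as a view of that new tensor.
c = torch.cat(parts).unsqueeze(0)
print(c.data_ptr() == a.data_ptr())  # False: c lives in separate memory
print(c.shape)  # torch.Size([1, 12, 16, 16])
```

So the cat approach gives the right shape but pays for one full copy of the data.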

c = a.view(1, -1, a.size(2), a.size(3))

This is efficient: view returns a tensor that shares storage with a, so no data is copied.
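A short sketch of what "no copy" means in practice, plus the one caveat: view can only merge dimensions that are laid out contiguously in memory. After a transpose that is no longer true, and view raises a RuntimeError; reshape is the usual fallback and copies only when it has to. The transposed tensor t below is purely illustrative.

```python
import torch

a = torch.randn(3, 4, 16, 16)
c = a.view(1, -1, a.size(2), a.size(3))

# c is a view: same underlying storage, so writes through c are visible in a.
print(c.data_ptr() == a.data_ptr())  # True
c[0, 0, 0, 0] = 42.0
print(a[0, 0, 0, 0].item())  # 42.0

# After transpose(0, 1), dims 0 and 1 of t are not adjacent in memory,
# so view cannot merge them without copying.
t = a.transpose(0, 1)
try:
    t.view(1, -1, t.size(2), t.size(3))
    view_ok = True
except RuntimeError:
    view_ok = False
print(view_ok)  # False

# reshape succeeds by copying into new memory when a view is impossible.
r = t.reshape(1, -1, t.size(2), t.size(3))
print(r.shape)  # torch.Size([1, 12, 16, 16])
```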

Thanks, but this only works for 4-D tensors and does not generalize to tensors with a different number of dimensions. Is there another way?

>>> a = torch.randn(3, 4, 16, 16, 16, 16)
>>> c = a.view(1, -1, *(a.size()[2:]))
>>> c.size()
torch.Size([1, 12, 16, 16, 16, 16])
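An alternative for the general case is torch.flatten(input, start_dim, end_dim) (also available as the Tensor.flatten method), which merges a range of dimensions directly without spelling out the remaining sizes. A minimal sketch; merge_first_two is a hypothetical helper name, not a PyTorch function. Per the PyTorch docs, flatten returns a view when the input allows it and copies only otherwise.

```python
import torch

def merge_first_two(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: merge dims 0 and 1, then add a leading size-1 dim.

    Works for any tensor with at least two dimensions.
    """
    return x.flatten(0, 1).unsqueeze(0)

a4 = torch.randn(3, 4, 16, 16)
a6 = torch.randn(3, 4, 16, 16, 16, 16)

print(merge_first_two(a4).shape)  # torch.Size([1, 12, 16, 16])
print(merge_first_two(a6).shape)  # torch.Size([1, 12, 16, 16, 16, 16])

# Same values as the view-based answer, and no copy for a contiguous input.
c = a6.view(1, -1, *a6.size()[2:])
print(torch.equal(merge_first_two(a6), c))  # True
print(merge_first_two(a6).data_ptr() == a6.data_ptr())  # True
```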

Thanks a lot, that is helpful!!