# How can I flatten two dimensions of a tensor?

Hi,

My question is this: suppose I have a tensor `a = torch.randn(3, 4, 16, 16)`, and I want to flatten its first two dimensions so that its shape becomes `(1, 12, 16, 16)`.

Right now the only way I know is `size = [1, -1] + list(a.size()[2:]); a = a.view(size)`, which I suspect is not the idiomatic PyTorch way to do it. Is there a smarter way?

```
>>> a = torch.randn(3, 4, 16, 16)
>>> c = torch.cat(a.unbind()).unsqueeze(0)
>>> c.size()
torch.Size([1, 12, 16, 16])
```

Thanks. Is this an efficient way? Doesn't it involve a memory copy in the `unbind` and `cat` steps?

I am not sure whether it is efficient.

```
c = a.view(1, -1, a.size(2), a.size(3))
```

This is an efficient way: `view` returns a tensor that shares storage with `a`, so no data is copied.
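To make the copy question concrete, here is a small sketch that compares the two approaches by checking `data_ptr()`, the address of the first element in the underlying storage. The variable names `c_cat` and `c_view` are just illustrative; this is one way to check, not the only one.

```python
import torch

a = torch.randn(3, 4, 16, 16)

# Approach 1: unbind + cat. torch.cat writes its result into a new buffer.
c_cat = torch.cat(a.unbind()).unsqueeze(0)

# Approach 2: view. Only the shape/stride metadata changes.
c_view = a.view(1, -1, a.size(2), a.size(3))

# Both give the same values and the same shape.
same_values = torch.equal(c_cat, c_view)

# Compare storage addresses to see which result aliases `a`.
cat_shares_memory = c_cat.data_ptr() == a.data_ptr()
view_shares_memory = c_view.data_ptr() == a.data_ptr()

print(same_values)         # True
print(cat_shares_memory)   # False: cat made a copy
print(view_shares_memory)  # True: view reuses a's storage
```

Because `c_view` aliases `a`, writing into one is visible through the other, which is worth keeping in mind if you modify the result in place.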

Thanks, but this only works for 4-d tensors and does not generalize to tensors with a different number of dimensions. Could you suggest another approach?

```
>>> a = torch.randn(3, 4, 16, 16, 16, 16)
>>> c = a.view(1, -1, *(a.size()[2:]))
>>> c.size()
torch.Size([1, 12, 16, 16, 16, 16])
```
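Another option that should work for any number of dimensions is `Tensor.flatten(start_dim, end_dim)` followed by `unsqueeze(0)`. The sketch below wraps it in a helper; the name `merge_first_two_dims` is made up for this example and is not part of PyTorch.

```python
import torch


def merge_first_two_dims(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (d0, d1, *rest) -> (1, d0 * d1, *rest)."""
    if t.dim() < 2:
        raise ValueError("expected a tensor with at least 2 dimensions")
    # flatten(0, 1) merges dims 0 and 1; unsqueeze(0) adds the leading size-1 dim.
    return t.flatten(0, 1).unsqueeze(0)


shapes = [(3, 4), (3, 4, 16, 16), (3, 4, 2, 2, 2, 2)]
results = [tuple(merge_first_two_dims(torch.randn(*s)).shape) for s in shapes]
for r in results:
    print(r)
```

One difference from `view`: `flatten` returns a view when the memory layout allows it and falls back to a copy otherwise, so it also accepts non-contiguous inputs (for example a transposed tensor), where `view` would raise an error.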

Thanks a lot, that is helpful!!