Hi,
My question is this: suppose I have a tensor a = torch.randn(3, 4, 16, 16), and I want to flatten the first two dimensions so that its shape becomes (1, 12, 16, 16).
Right now the only way I know is: size = [1, -1] + list(a.size()[2:]); a = a.view(size)
which I don't think is the idiomatic PyTorch way to do it. Is there a smarter way?
>>> a = torch.randn(3, 4, 16, 16)
>>> c = torch.cat(a.unbind()).unsqueeze(0)
>>> c.size()
torch.Size([1, 12, 16, 16])
Thanks. Is this efficient? Doesn't it involve a memory copy in unbind and cat?
I am not sure whether it is efficient.
c = a.view(1, -1, a.size(2), a.size(3))
This is efficient, since view only changes the shape and does not copy the data.
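If you want to check this yourself, here is a rough sketch that compares the two approaches by looking at data pointers. It assumes a is contiguous, as it is straight out of torch.randn; the variable names v and c are just for illustration.

```python
import torch

a = torch.randn(3, 4, 16, 16)

# view only reinterprets the shape; it shares storage with a
v = a.view(1, -1, a.size(2), a.size(3))

# cat builds a new tensor from the unbound slices, so the data is copied
c = torch.cat(a.unbind()).unsqueeze(0)

# Same data pointer means no copy was made
view_shares_memory = v.data_ptr() == a.data_ptr()
cat_shares_memory = c.data_ptr() == a.data_ptr()

print(v.shape, view_shares_memory)
print(c.shape, cat_shares_memory)
print(torch.equal(v, c))  # both approaches hold the same values
```

Both results have the same shape and values; the difference is only whether new memory is allocated.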
Thanks, but this only works for 4-D tensors and doesn't generalize to tensors with other numbers of dimensions. Could you suggest another method?
>>> a = torch.randn(3, 4, 16, 16, 16, 16)
>>> c = a.view(1, -1, *(a.size()[2:]))
>>> c.size()
torch.Size([1, 12, 16, 16, 16, 16])
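Another option that might read a little cleaner is flatten with start and end dims, followed by unsqueeze. A minimal sketch, where merge_first_two_dims is a made-up helper name and not part of PyTorch; as far as I know flatten returns a view when the input is contiguous and may copy otherwise.

```python
import torch

def merge_first_two_dims(t):
    # Hypothetical helper (not a PyTorch function):
    # merge dims 0 and 1 into one, then add a leading dim of size 1.
    return t.flatten(0, 1).unsqueeze(0)

# Works for any tensor with at least two dimensions
for shape in [(3, 4, 16), (3, 4, 16, 16), (3, 4, 2, 2, 2, 2)]:
    out = merge_first_two_dims(torch.randn(*shape))
    print(shape, "->", tuple(out.shape))
```

This gives the same result as a.view(1, -1, *a.size()[2:]) without having to build the size list by hand.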
Thanks a lot, that is helpful!!