Hi, I am trying to see if there is an equivalent to the np.reshape() command in PyTorch. Inspecting the docs and googling around, I can't seem to find anything that does this, except in the “old” torch manual. Does PyTorch support this for torch.Tensors, and if so, is there a reason I can't find it? How can I reshape a torch.Tensor? Thanks!
I think the view method would be the right tool for that. E.g.,
>>> import torch
>>> t = torch.ones((2, 3, 4))
>>> t.size()
torch.Size([2, 3, 4])
>>> t.view(-1, 12).size()
torch.Size([2, 12])
Interesting - thanks @rasbt! I will take a look at it. From your knowledge, is it more or less similar to reshape? Any caveats to be aware of? thanks.
A caveat is that the tensor has to be contiguous. As far as I know, np.reshape is more forgiving here and copies the data when it cannot return a view.
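To make the contiguity caveat concrete, here is a minimal sketch. It assumes a transposed tensor as the non-contiguous example (the variable names are made up for illustration) and uses .contiguous() to copy the data into a fresh layout before calling view.

```python
import torch

t = torch.arange(6).view(2, 3)   # freshly created, so contiguous
tt = t.t()                       # transpose: same storage, different strides

print(t.is_contiguous())    # True
print(tt.is_contiguous())   # False

# view cannot flatten the transposed tensor without copying, so it raises.
try:
    tt.view(6)
    view_failed = False
except RuntimeError:
    view_failed = True

# Copying into a contiguous layout first makes view work again.
flat = tt.contiguous().view(6)
print(view_failed)   # True
print(flat)          # tensor([0, 3, 1, 4, 2, 5])
```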
Is there a way to do this in a
Since people have asked about this several times – yes. You just have to define the module version yourself:
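import torch.nn as nn

class View(nn.Module):
    def __init__(self, *shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, input):
        # use the shape stored on the module
        return input.view(*self.shape)

# nn.Sequential takes the modules as separate arguments, not as a list
sequential_model = nn.Sequential(nn.Linear(10, 20), View(-1, 5, 4))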
One more thing about the np.reshape function in NumPy is that you can specify the order of reshaping. But I think view() in torch has no such feature.
You can use permute in pytorch to specify its order of reshaping.
t = torch.rand((2, 3, 4))
t = t.permute(1, 0, 2)
This reorders the dimensions of the tensor.
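As a rough sketch of how permute and view could be combined to imitate the column-major order that np.reshape offers, consider the helper below. Note that reshape_fortran is a hypothetical name, not a torch function. The idea is to reverse the axes, reshape to the reversed target shape, and reverse the axes back.

```python
import torch

def reshape_fortran(t, shape):
    """Hypothetical helper (not part of torch): reshape in column-major order."""
    # Reverse the axes so that the first index becomes the fastest-varying one.
    t = t.permute(*reversed(range(t.dim()))).contiguous()
    # Reshape to the reversed target shape in the usual row-major order.
    t = t.view(*reversed(shape))
    # Reverse the axes again to arrive at the requested shape.
    return t.permute(*reversed(range(len(shape))))

x = torch.arange(6).view(2, 3)      # [[0, 1, 2], [3, 4, 5]]
y = reshape_fortran(x, (3, 2))
print(y)                            # tensor([[0, 4], [3, 2], [1, 5]])
```

For comparison, a plain x.view(3, 2) gives [[0, 1], [2, 3], [4, 5]], so the two orders really do produce different results.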
Do you think it is a good idea to add a reshape function as a copy of view in PyTorch to accommodate heavy NumPy users?
Yes, I’d like to add a
Why is this not called reshape?
My guess is that it’s been done to be consistent with torch rather than numpy, which makes sense. However, yeah, some of the naming conventions are a bit annoying for people coming from NumPy, not Torch
After some experimenting, I think torch.Tensor.view() is the same as np.reshape(), and torch.Tensor.permute() is the same as np.transpose(). But I have a question about the use of np.reshape: in the image processing domain, reshape changes the structural information of the image, and that is fatal.
In the image processing field, we should use permute in my opinion, so what is the point of view() existing? Thank you!
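A small sketch of the distinction raised here, assuming a tiny made-up "image" stored in height x width x channel layout: permute moves the channel axis while keeping each pixel intact, whereas view only re-chunks the same memory order and so mixes pixels and channels.

```python
import torch

# Arbitrary illustrative sizes: 2x2 image with 3 channels, values 0..11.
H, W, C = 2, 2, 3
img = torch.arange(H * W * C).view(H, W, C)

chw_permute = img.permute(2, 0, 1)   # channel axis moved to the front
chw_view = img.view(C, H, W)         # same element order, different grouping

# With permute, plane 0 really is channel 0 of every pixel.
print(chw_permute[0])   # tensor([[0, 3], [6, 9]])
# With view, plane 0 is just the first four stored values.
print(chw_view[0])      # tensor([[0, 1], [2, 3]])
```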
One use case: suppose you feed an image through several conv layers, and then you want to run the extracted features through a linear layer. The conv output is of shape (batch_size, channels_out, height_out, width_out). Now we want the linear layer to take all features from all channels and from every position in the image, but nn.Linear only acts on the last dimension of its input. That is where view comes in handy.
linear_input = conv_output.view(batch_size, channels_out*height_out*width_out)
linear_input is now of shape
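The use case above can be sketched end to end as follows. The layer sizes, the input size and the names conv, images and fc are assumptions picked only for illustration.

```python
import torch
import torch.nn as nn

batch_size = 8
conv = nn.Conv2d(3, 16, kernel_size=3)        # assumed layer: 3 -> 16 channels, no padding
images = torch.randn(batch_size, 3, 10, 10)   # assumed input: 10x10 RGB images

conv_output = conv(images)
_, channels_out, height_out, width_out = conv_output.size()

# Flatten everything except the batch dimension.
linear_input = conv_output.view(batch_size, channels_out * height_out * width_out)

fc = nn.Linear(channels_out * height_out * width_out, 5)   # 5 outputs, arbitrary
out = fc(linear_input)

print(conv_output.size())    # torch.Size([8, 16, 8, 8])
print(linear_input.size())   # torch.Size([8, 1024])
print(out.size())            # torch.Size([8, 5])
```

Writing conv_output.view(batch_size, -1) would give the same result and avoids spelling out the product by hand.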
I think a reshape method is now available in torch version 0.4.0.
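A quick sketch of how reshape behaves compared with view, assuming torch 0.4.0 or later: reshape returns a view when it can and otherwise falls back to a copy, so it also works on non-contiguous tensors.

```python
import torch

t = torch.arange(6).view(2, 3)
tt = t.t()                      # non-contiguous transpose

# reshape succeeds where view would raise; a copy is made in this case.
r = tt.reshape(6)
print(r)                        # tensor([0, 3, 1, 4, 2, 5])

# On a contiguous tensor, reshape shares storage with the input, like view.
v = t.reshape(3, 2)
shares_storage = v.data_ptr() == t.data_ptr()
print(shares_storage)           # True
```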