Is there a convenient way of reshaping only last n dimensions?


T = torch.randn(u,v,w,x,y,z)

where u, v, w, x, y, z are some dimensions. If I want to reshape it to (u, v, w, x, y*z), I can write:

T = torch.randn(u,v,w,x,y,z).view(-1,v,w,x,y*z)

But I cannot write:

T = torch.randn(u,v,w,x,y,z).view(-1,y*z)

That would flatten all of the preceding dimensions as well.
Is there currently some convenient notation in PyTorch that reshapes only the final dimensions, something like:

T = torch.randn(u,v,w,x,y,z).view_last(y*z)
U = T.view_last(y,z)
I’m interested in this for the case where u, v, w, x are not known ahead of time, and I would rather avoid reading T.shape.
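For reference, the hypothetical view_last can be approximated with a small helper function. This is a minimal sketch, not a PyTorch API: the name `view_last` and its `n` parameter (how many trailing dimensions to replace) are assumptions made for illustration. It still reads `t.shape` internally, but the caller only has to name the trailing sizes.

```python
import torch


def view_last(t: torch.Tensor, n: int, *sizes: int) -> torch.Tensor:
    """Hypothetical helper (not part of PyTorch): replace the last `n`
    dimensions of `t` with `sizes`, keeping all leading dimensions."""
    if n < 1 or n > t.dim():
        raise ValueError(f"n must be between 1 and {t.dim()}, got {n}")
    leading = t.shape[:-n]  # the dimensions we leave untouched
    # reshape (rather than view) also works on non-contiguous inputs
    return t.reshape(*leading, *sizes)


T = torch.randn(3, 4, 5, 6, 7, 8)
flat = view_last(T, 2, 7 * 8)    # merge the last two dims
back = view_last(flat, 1, 7, 8)  # split the last dim again
print(flat.shape)  # torch.Size([3, 4, 5, 6, 56])
print(back.shape)  # torch.Size([3, 4, 5, 6, 7, 8])
```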

You can try torch.flatten, which flattens a given range of dimensions (from start_dim to end_dim).
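A short sketch of that suggestion for the shapes in the question, with concrete sizes 3 to 8 assumed for u through z. Negative values of start_dim count from the end, and end_dim defaults to -1, so only the last two dimensions are merged without mentioning the leading ones:

```python
import torch

T = torch.randn(3, 4, 5, 6, 7, 8)  # u, v, w, x, y, z

# Merge only the last two dimensions (y and z); the leading dims are untouched.
flat = torch.flatten(T, start_dim=-2)
print(flat.shape)  # torch.Size([3, 4, 5, 6, 56])

# The method form is equivalent.
assert torch.equal(flat, T.flatten(-2))
```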


Thanks for the suggestion; this solves my issue when I want to reduce the number of dimensions, but not when I want to increase it again. (I added an additional example to clarify my post.)

Perhaps it is not a good idea to freely reshape the last dimensions, since it might be ambiguous? I haven’t really thought about it.
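The ambiguity is real: once the last two dimensions are merged, the flattened size alone does not say how to split it again. A small sketch with y=7 and z=8 (sizes chosen for illustration): a trailing dimension of 56 can be viewed as (7, 8), (8, 7) or (4, 14), all valid views of the same memory but different tensors, so the original sizes have to come from somewhere.

```python
import torch

T = torch.randn(2, 7, 8)  # leading dim 2, y=7, z=8
flat = T.flatten(-2)      # shape (2, 56): y and z are no longer recorded

# Every factorization of 56 is a legal view of the same memory...
a = flat.view(2, 7, 8)
b = flat.view(2, 8, 7)
c = flat.view(2, 4, 14)

# ...but only the split that matches the original sizes recovers T.
print(torch.equal(a, T))   # True
print(a.shape == b.shape)  # False
```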

What do you mean by extending the dimensions? I didn’t find your example.

import torch

T = torch.randn(3,4,5,6,7,8)
all_but_last_two_dims = T.size()[:-2]   # torch.Size([3, 4, 5, 6])
U = T.view(*all_but_last_two_dims, -1)  # shape (3, 4, 5, 6, 56)

This is probably not the ideal solution, especially if you want to flatten dimensions in the middle of a tensor, but for your use case it should work.
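One caveat with this pattern: view only works when the merged dimensions are laid out compatibly in memory. If T has been transposed or permuted, view raises a RuntimeError, while reshape (which copies when it has to) still succeeds. A small sketch, using a transpose to make the last two dimensions non-contiguous:

```python
import torch

# Shape (3, 4, 5, 6, 8, 7); the last two dims are no longer contiguous.
T = torch.randn(3, 4, 5, 6, 7, 8).transpose(-1, -2)
all_but_last_two_dims = T.size()[:-2]

view_failed = False
try:
    U = T.view(*all_but_last_two_dims, -1)
except RuntimeError:
    view_failed = True
    # reshape falls back to a copy when a view is impossible
    U = T.reshape(*all_but_last_two_dims, -1)

print(view_failed)  # True
print(U.shape)      # torch.Size([3, 4, 5, 6, 56])
```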

T = torch.randn(u,v,w,x,y,z).view_last(y*z)
U = T.view_last(y,z)

Basically, I want to first flatten and then “re-inflate”.

Well, in my example you could just save the last two dims and then expand those. I’m not totally sure what you find wrong with my answer. It’s a bit confusing because you are using a function that doesn’t actually exist, and you never save the original tensor T. I think your example should look like this:

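import torch

original_tensor = torch.randn(u,v,w,x,y,z)
# the API from the question (view_last does not exist in PyTorch)
T = original_tensor.view_last(y*z)
U = original_tensor.view_last(y,z)
# my suggestion: save the leading and trailing sizes, then view with them
rem_dims, last_2_dims = original_tensor.size()[:-2], original_tensor.size()[-2:]
T = original_tensor.view(*rem_dims, -1)  # flatten the last two dims
U = T.view(*rem_dims, *last_2_dims)      # re-inflate them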

Alternative Answer

Named Tensors and specifically this
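In recent PyTorch releases, unflatten is no longer limited to named tensors: Tensor.unflatten(dim, sizes) accepts a plain integer dim, including negative ones, and one entry of sizes may be -1. Together with flatten this gives the “flatten, then re-inflate” round trip without reading the leading sizes; only y and z need to be known. A minimal sketch, assuming such a version and concrete sizes chosen for illustration:

```python
import torch

T = torch.randn(3, 4, 5, 6, 7, 8)  # u, v, w, x, y, z

flat = T.flatten(-2)                   # (3, 4, 5, 6, 56); no leading sizes needed
restored = flat.unflatten(-1, (7, 8))  # split the last dim back into (y, z)
inferred = flat.unflatten(-1, (7, -1)) # z inferred from the flattened size

print(restored.shape)            # torch.Size([3, 4, 5, 6, 7, 8])
print(torch.equal(restored, T))  # True
```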





This seems incorrect to me: you cannot view something which is

unless u, v, w, x = 1.

Sorry, I didn’t mean to suggest that your answer was wrong; the notation is just slightly clunky.

I think my original question is perhaps a bit unclear, but what I want to avoid is having to explicitly mention the other dimensions: if I have an input where I’m only certain of the last two dimensions, I don’t want to have to read T.shape to get the others. But thanks for the suggestion.

I’m afraid it’s not implemented. Tensors don’t store their previous shapes: as soon as you reshape one, the information about its previous dimensionality is lost.
The most you can do is what @dhpollack suggested, i.e. save the dimensionality manually.
Given a tensor T:

T_flatten = T.view(*T.shape[:-2],-1)
T_unflatten = T_flatten.view(*T_flatten.shape[:-1],y,z)

As I mentioned, as soon as you compute T_flatten, the values of y and z are lost unless you save them, for example by taking them from the original T:

T_flatten = T.view(*T.shape[:-2],-1)
T_unflatten = T_flatten.view(*T_flatten.shape[:-1],*T.shape[-2:])

So, in short: if you don’t overwrite T, you can do it that way. If you overwrite T, you need to save y and z first.
