[Solved] Simple question about keeping dims when slicing a tensor

Hi,
I am new to PyTorch; however, I used (Lua) Torch previously.

Here is a question that has been bothering me: how do you slice a tensor and keep its dims in PyTorch?

In Torch I could write it like this:

val = torch.rand(4,3,256,256)
val_keep = val[{{},{1},{},{}}]    -- output is (4x1x256x256)
val_notkeep = val[{{},1,{},{}}]   -- output is (4x256x256)

However, it seems PyTorch drops the dimension automatically:

val = torch.rand(4,3,256,256)
val_notkeep = val[:,1,:,:]  # output is (4x256x256)

So how do you slice a tensor and keep the dims in PyTorch?


Hi,

in Python (besides PyTorch, this also works with lists and NumPy arrays) you can do this by using a "one-element slice":

val_keep = val[:,1:2,:,:]
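
For example, with the shapes from your question (a quick sketch showing the resulting sizes):

import torch

val = torch.rand(4, 3, 256, 256)
print(val[:, 1:2, :, :].shape)  # torch.Size([4, 1, 256, 256]) -- dim kept
print(val[:, 1, :, :].shape)    # torch.Size([4, 256, 256])    -- dim dropped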

Best regards

Thomas


Thanks a lot! It works great!

Is this the recommended practice? What if you have many dimensions? Thanks!


Another option is to use unsqueeze:

import torch

t = torch.ones(10, 5, 2)
t[:, 1, ...].unsqueeze(1)          # indexing drops dim 1; unsqueeze restores it -> (10, 1, 2)
# or
t.moveaxis(1, 0)[1].unsqueeze(1)   # move dim 1 to the front, index it, then restore -> (10, 1, 2)

Note that you can use the ellipsis ... instead of : for the remaining dimensions.

And yet another way of doing it, which has my (personal) preference:

torch.randn(8,1024,256)[0].shape
> torch.Size([1024, 256])

torch.randn(8,1024,256)[[0]].shape
> torch.Size([1, 1024, 256])

Or more generally:

torch.randn(8,1024,256)[:,0,:].shape
> torch.Size([8, 256])

torch.randn(8,1024,256)[:,[0],:].shape
> torch.Size([8, 1, 256])

I just read a comment about this kind of slicing for numpy.ndarray, which claimed that the operation copies the data.

Is that true? Especially for PyTorch?
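
A quick way to check this yourself (a sketch; in both NumPy and PyTorch, a basic slice such as 1:2 returns a view, while indexing with a list, as in the [[0]] examples above, returns a copy; unsqueeze also returns a view):

import torch

val = torch.rand(4, 3, 256, 256)

view = val[:, 1:2, :, :]   # basic slice: a view into val's storage
copy = val[:, [1], :, :]   # advanced (list) indexing: a separate copy

view[0, 0, 0, 0] = 42.0    # writing through the view mutates val
copy[0, 0, 0, 0] = -1.0    # writing to the copy does not
print(val[0, 1, 0, 0])     # tensor(42.)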