[Solved] Simple question about keeping dims when slicing a tensor

Hi,
I am new to PyTorch; however, I have used Torch before.

Here is a question that has been bothering me: how do I slice a tensor and keep its dims in PyTorch?

In Torch I could write it like this:

val = torch.rand(4,3,256,256);
val_keep = val[{{},{1},{},{}}] -- output is (4x1x256x256)
val_notkeep = val[{{},1,{},{}}] -- output is (4x256x256)

However, it seems Python drops the dim automatically in PyTorch:

val = torch.rand(4,3,256,256);
val_notkeep = val[:,1,:,:] #output is (4x256x256)

So how do I slice the tensor and keep the dims in PyTorch?


Hi,

in Python (this also works with lists and NumPy arrays, not just PyTorch) you can do this by using a “one-element slice”:

val_keep = val[:,1:2,:,:]
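
For example, here is a quick check of the resulting shapes with the 4x3x256x256 tensor from the question (a minimal sketch using only standard PyTorch indexing):

import torch

val = torch.rand(4, 3, 256, 256)
print(val[:, 1, :, :].shape)    # torch.Size([4, 256, 256]), dim dropped
print(val[:, 1:2, :, :].shape)  # torch.Size([4, 1, 256, 256]), dim kept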

Best regards

Thomas


Thanks a lot! It works great!

Is this the recommended practice? What if you have many dimensions? Thanks!


Another option is to use unsqueeze:

import torch

t = torch.ones(10, 5, 2)
t[:, 1, ...].unsqueeze(1)         # shape (10, 1, 2): indexing drops dim 1, unsqueeze adds it back
# or
t.moveaxis(1, 0)[1].unsqueeze(1)  # same result: move dim 1 to the front, index it, then unsqueeze

Note that you can use the ellipsis ... instead of : for the remaining dimensions.
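
To illustrate with a tensor that has more dimensions (a small sketch; the shapes below are made up for the example), both approaches keep the indexed dim no matter how many trailing dims there are:

import torch

t = torch.ones(10, 5, 2, 7, 3)
print(t[:, 1:2, ...].shape)             # torch.Size([10, 1, 2, 7, 3]), one-element slice
print(t[:, 1, ...].unsqueeze(1).shape)  # torch.Size([10, 1, 2, 7, 3]), index then unsqueeze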