Select rows from a 2D tensor

I have a tensor x:

x.shape
torch.Size([90, 50])
dtype=torch.float64, device='cuda:0'

and I need to select rows defined by the list:

loc = [0, 0, 0, 1, 0, 1, 1, 0, 1,..., 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1]
# len(loc) is 90

and store the selected rows in a tensor y.
I think torch.masked_scatter_() might do this, but I could not get it to work.


You can do this the same way you would with NumPy indexing:

 import torch
 x = torch.rand(5,4)
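 # 0/1 mask, one entry per row; newer PyTorch versions deprecate uint8 masks in favor of torch.bool
 loc = torch.ByteTensor([0,1,0,0,1])
 y = x[loc]  # keeps the rows where loc is 1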

This stores the 2nd and 5th rows of x in y, as indicated by the ones in the loc tensor. Is this what you need?
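For the tensor in the original question, a minimal sketch might look like the following. The contents of loc here are made up (the real list is elided in the question), and mask is just a name introduced for illustration. The idea is to turn the 0/1 list into a boolean tensor on the same device as x before indexing.

```python
import torch

# Stand-in for the tensor in the question (90 rows, 50 columns).
# In the question it lives on 'cuda:0'; here it stays on the default device.
x = torch.rand(90, 50, dtype=torch.float64)

# Hypothetical 0/1 list of length 90; the actual values are not shown in full.
loc = [i % 2 for i in range(90)]

# Build a boolean mask on the same device as x, then index the rows with it.
mask = torch.tensor(loc, dtype=torch.bool, device=x.device)
y = x[mask]

print(y.shape)  # one row per 1 in loc, all 50 columns kept
```

Passing device=x.device keeps the mask next to the data, so the same lines should work unchanged when x is on the GPU.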


Yes Diego, it worked. Thanks a bunch.


import torch
xy = torch.rand(5,4)
loc = torch.tensor([False,True,False,False,True])
y = xy[loc]
print(y)

It worked for me like this, with a boolean mask. I hope it is useful for someone else. Thanks to @Diego.
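If you already have a 0/1 mask, here is a small sketch of converting it to bool and checking that it picks the same rows as selecting by row index. The names byte_loc, bool_loc, rows, y_mask and y_index are only for illustration.

```python
import torch

xy = torch.rand(5, 4)

# An existing 0/1 mask stored as uint8 (what torch.ByteTensor produces).
byte_loc = torch.tensor([0, 1, 0, 0, 1], dtype=torch.uint8)

# Convert to a boolean mask and select rows with it.
bool_loc = byte_loc.bool()
y_mask = xy[bool_loc]

# Equivalent selection by explicit row indices.
rows = torch.nonzero(bool_loc).squeeze(1)  # positions where the mask is True
y_index = xy.index_select(0, rows)

print(torch.equal(y_mask, y_index))
```

Both forms return a new tensor with the selected rows in their original order; the index form can be handy when you need the row positions afterwards.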