Select rows from a 2D tensor

You can do it the same way you would with NumPy boolean indexing, like so:

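 import torch
 x = torch.rand(5, 4)
 # Boolean mask over the rows: True marks a row to keep
 # (uint8/ByteTensor masks still work in older versions but are deprecated)
 loc = torch.tensor([0, 1, 0, 0, 1], dtype=torch.bool)
 y = x[loc]  # rows at index 1 and 4, i.e. the 2nd and 5th rows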

This stores in y the 2nd and 5th rows of x, as indicated by the ones in your loc mask. Is this what you need?
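If you would rather work with row indices than a mask, here is a small sketch (assuming a reasonably recent PyTorch with boolean masks and `nonzero(as_tuple=True)`) showing that converting the mask to 0-based indices and using `index_select` picks out the same rows:

```python
import torch

x = torch.rand(5, 4)

# Boolean mask over dim 0: True keeps the corresponding row
mask = torch.tensor([False, True, False, False, True])
rows_by_mask = x[mask]  # shape (2, 4)

# Same selection using integer row indices (0-based)
idx = mask.nonzero(as_tuple=True)[0]    # tensor([1, 4])
rows_by_index = x.index_select(0, idx)  # shape (2, 4)

assert torch.equal(rows_by_mask, rows_by_index)
print(rows_by_mask.shape)  # torch.Size([2, 4])
```

The mask must have the same length as the first dimension of x; indexing with a plain list such as `x[[1, 4]]` gives the same result too.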
