Selecting entries of a matrix based on a 1D tensor of indices


My question is: I have a 2D tensor and a 1D LongTensor that stores a list of indices, and I want to select one entry from each row of the 2D tensor based on the 1D LongTensor. How can I achieve that in PyTorch?

For example, if a = [[1,2,3],[4,5,6],[7,8,9]] and b = [2,1,0], then I would like to get [3, 5, 7].

Also, if I call torch.sum on [3, 5, 7] and then take the derivative of the result, will the partial derivatives be calculated correctly?

Thanks a lot!


This should do it:

a.gather(1, b.unsqueeze(1))

The gradient will be correct as long as all values in b are unique.
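For completeness, here is a small runnable sketch of the suggestion above, using the tensors from the question. It selects one entry per row with gather, sums the result, and backpropagates to check that the gradients land where expected:

```python
import torch

# Example tensors from the question; requires_grad so we can backprop.
a = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]], requires_grad=True)
b = torch.tensor([2, 1, 0])  # one column index per row

# gather picks a[i, b[i]] for each row i; unsqueeze makes b a column vector.
picked = a.gather(1, b.unsqueeze(1)).squeeze(1)
print(picked)  # tensor([3., 5., 7.], grad_fn=<SqueezeBackward1>)

# Summing and backpropagating puts a 1 at each selected position.
picked.sum().backward()
print(a.grad)
# tensor([[0., 0., 1.],
#         [0., 1., 0.],
#         [1., 0., 0.]])
```

Each row contributes exactly one entry to the sum, so its gradient is 1 at the selected column and 0 elsewhere.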


The values in b shouldn’t have to be unique, since they’re indexing into separate rows?

Yes, of course. You’re right. They should only be unique within a row, which is not a problem here.

Thanks a lot for your answer! It works now.