Select elements of a 2-d tensor by rows and columns

(Mahsa) #1

I am new to PyTorch. I have a 2-d tensor and lists of row and column indices, and I want to select the element at each (row, column) pair, like this:

x = [[1,2,3], [4,5,6], [7,8,9]]
row = [0,1], col = [1,2]

I want to have:

output = [1,6]

(Doug Friedman) #2

Sounds like you want index_select or masked_select?

(Mahsa) #3

My input is too large, and I run into memory problems when I create a mask from the rows and columns.

index_select and gather need to select data along a specific dimension.


If I understand the question correctly, the output should be:
out = [2,6]
and it can be computed like this:

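import torch
x = torch.tensor([[1,2,3], [4,5,6], [7,8,9]])
row = torch.tensor([0,1])
col = torch.tensor([1,2])
res = []  # collects one selected element per (row, col) pair
for idx in range(len(row)):
    res.append(x[row[idx], col[idx]])
print(res)
out = torch.tensor(res)
print(out)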

[tensor(2), tensor(6)]
tensor([2, 6])
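
For what it's worth, the Python loop can probably be avoided. The sketch below assumes row and col are integer tensors of the same length; indexing x with both at once pairs them element-wise, so no mask the size of x has to be built. The flat-index variant is just an alternative way of writing the same lookup, and the names out and flat are only for illustration.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
row = torch.tensor([0, 1])
col = torch.tensor([1, 2])

# Index tensors of equal shape are paired element-wise:
# out[i] = x[row[i], col[i]]
out = x[row, col]
print(out)  # tensor([2, 6])

# Same lookup through flat indices into the row-major flattened tensor
flat = x.reshape(-1)[row * x.size(1) + col]
print(flat)  # tensor([2, 6])
```

Only the selected elements are materialized, so the extra memory scales with the number of (row, col) pairs rather than with the size of x.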