Select elements of a 2-D tensor by rows and columns


(Mahsa) #1

I am new to PyTorch. I have a 2-D tensor and two lists of row and column indices. I want to select the element at each (row, column) pair, as follows:

x = [[1,2,3],
     [4,5,6],
     [7,8,9]]
row = [0,1], col = [1,2]

I want to have:

output = [1,6]
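A minimal sketch, assuming `x`, `row` and `col` are converted to tensors: PyTorch's integer-array (advanced) indexing pairs `row[i]` with `col[i]` and returns `x[row[i], col[i]]` for each `i`, without building a mask. Note that `x[0][1]` is 2, so the pairs (0, 1) and (1, 2) give 2 and 6.

```python
import torch

# Example data from the question, as tensors
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
row = torch.tensor([0, 1])
col = torch.tensor([1, 2])

# Advanced indexing: one output element per (row[i], col[i]) pair
out = x[row, col]
print(out)  # tensor([2, 6])
```

The result has `len(row)` elements, so memory grows with the number of pairs rather than with the size of `x`.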

(Doug Friedman) #2

Sounds like you want index_select or masked_select?

https://pytorch.org/docs/stable/torch.html#torch.index_select

https://pytorch.org/docs/stable/torch.html#torch.masked_select
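For reference, a sketch of the `masked_select` route on the example data (the tensor names are assumptions carried over from the question). It needs a boolean mask with as many entries as `x` itself, which is what makes it memory-hungry for large inputs.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
row = torch.tensor([0, 1])
col = torch.tensor([1, 2])

# Boolean mask with the same shape as x; switch on each (row, col) pair
mask = torch.zeros_like(x, dtype=torch.bool)
mask[row, col] = True

# masked_select returns a 1-D tensor in row-major order of x,
# not in the order of the pairs, and repeated pairs appear only once.
out = torch.masked_select(x, mask)
print(out)  # tensor([2, 6])
```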


(Mahsa) #3

My input is too large, and creating a mask from the rows and columns runs out of memory.

index_select and gather both select along a single dimension only.
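One way around the single-dimension restriction, sketched here as an assumption rather than anything stated in the thread: convert each (row, col) pair into a linear index `row * num_cols + col` into the flattened tensor. `index_select`, `gather` and `torch.take` then all work along one dimension, and the only extra memory is the index tensor of length `len(row)`.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
row = torch.tensor([0, 1])
col = torch.tensor([1, 2])

# In row-major order, element (r, c) sits at position r * num_cols + c
flat_idx = row * x.size(1) + col      # tensor([1, 5])
flat = x.reshape(-1)                  # a view when x is contiguous

# Each call now selects along dim 0 of a 1-D tensor
a = torch.index_select(flat, 0, flat_idx)
b = torch.gather(flat, 0, flat_idx)
c = torch.take(x, flat_idx)           # take treats x as 1-D itself

print(a, b, c)  # tensor([2, 6]) tensor([2, 6]) tensor([2, 6])
```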


#6

If I understand the question correctly, the output should be
out = [2,6]
because x[0][1] is 2. It can be done like this:

import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
row = torch.tensor([0, 1])
col = torch.tensor([1, 2])

# Pick out one element per (row, col) pair
res = []
for r, c in zip(row, col):
    res.append(x[r, c])
print(res)

# Combine the 0-d tensors into a single 1-D tensor
out = torch.stack(res)
print(out)

Output:

[tensor(2), tensor(6)]
tensor([2, 6])