Looking for an efficient way to index a 2D tensor by another 2D tensor

I have a tensor, say

A = tensor([[0, 0],
            [0, 2],
            [0, 3],
            [0, 4],
            [0, 5],
            [0, 6],
            [1, 0],
            [1, 1],
            [1, 4],
            [1, 5],
            [1, 6]])

and another tensor

b = tensor([[0, 2], [1, 2]])

I would like to find an efficient way to index into A by b such that the result is

result = tensor([[0, 3], [1, 4]])

That is, match the values in A’s first column (i.e. [0, …, 1, …]) against the values in b’s first column (i.e. [0, 1]), and then use b’s second column (i.e. [2, 2]) as a row position within each matched group of A.
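To pin down the intended semantics, here is a slow but explicit reference version: for every row (g, k) of b, it takes the k-th row of A among the rows whose first column equals g. It is only a sketch for checking results against, not an efficient solution, and the function name `index_by_group_loop` is hypothetical.

```python
import torch

A = torch.tensor([[0, 0], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6],
                  [1, 0], [1, 1], [1, 4], [1, 5], [1, 6]])
b = torch.tensor([[0, 2], [1, 2]])

def index_by_group_loop(A, b):
    """Hypothetical reference version: for each row (g, k) of b, return the
    k-th row of A among the rows whose first column equals g."""
    out = []
    for g, k in b.tolist():
        group_rows = A[A[:, 0] == g]  # rows of A in group g, original order
        out.append(group_rows[k])
    return torch.stack(out)

result = index_by_group_loop(A, b)
print(result.tolist())  # [[0, 3], [1, 4]]
```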

Thanks

Can’t give you a straight answer to this, but it sounds like you might want to separate the two steps: first do the matching on the 0th column (which tells you where each group starts in A), then use e.g. torch.gather to gather values from A[:, 1] by indexing with b[:, 1] shifted by those group starts.
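A minimal sketch of that two-step idea, assuming A’s rows are grouped contiguously by their first column and every group named in b occurs in A. The variable names are illustrative. It relies on `torch.argmax` returning the index of the first maximal value, which is how the start of each group is located.

```python
import torch

A = torch.tensor([[0, 0], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6],
                  [1, 0], [1, 1], [1, 4], [1, 5], [1, 6]])
b = torch.tensor([[0, 2], [1, 2]])

# Step 1: match on the 0th column. match[i, j] is True when row j of A
# belongs to the group named in row i of b; shape (len(b), len(A)).
match = A[:, 0].unsqueeze(0) == b[:, 0].unsqueeze(1)

# Index of the first matching row per group (argmax returns the first maximum).
# Assumes each group's rows are contiguous in A.
start = match.long().argmax(dim=1)       # [0, 6]

# Step 2: shift the within-group positions and gather from A[:, 1].
rows = start + b[:, 1]                   # [2, 8]
values = torch.gather(A[:, 1], 0, rows)  # [3, 4]

result = torch.stack([b[:, 0], values], dim=1)
print(result.tolist())  # [[0, 3], [1, 4]]
```

The (len(b), len(A)) comparison matrix keeps everything vectorized, at the cost of memory proportional to both sizes.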

Thanks for your suggestion.
I converted it into a one-dimensional problem with torch.nonzero, using offsets computed from the per-batch mask sums, and worked out a solution.
Instead of the original A, I get a flattened version, like

A = tensor([[ 0],
            [ 2],
            [ 3],
            [ 4],
            [ 5],
            [ 7],
            [ 8],
            [11],
            [12]])
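For illustration, here is one way such a flattened A and the offsets could be produced with torch.nonzero and a per-batch mask sum. The boolean `mask` below is hypothetical (the post does not show it); it was chosen only so that the outputs reproduce the tensors shown here.

```python
import torch

# Hypothetical boolean mask (2 batches x 7 columns), NOT from the original post;
# picked only so the results match the flattened A and offset tensors.
mask = torch.tensor([
    [1, 0, 1, 1, 1, 1, 0],
    [1, 1, 0, 0, 1, 1, 0],
], dtype=torch.bool)

# Positions of the True entries in the row-major flattened mask, shape (N, 1).
A_flat = torch.nonzero(mask.reshape(-1))
print(A_flat.reshape(-1).tolist())  # [0, 2, 3, 4, 5, 7, 8, 11, 12]

# Number of True entries per batch, prefixed with 0, shape (batch + 1, 1).
offset = torch.cat([torch.zeros(1, dtype=torch.long), mask.sum(dim=1)]).reshape(-1, 1)
print(offset.reshape(-1).tolist())  # [0, 5, 4]
```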

and also calculate the offsets along the batch dimension (the number of entries in each batch, with a leading 0),

offset = tensor([[0],
                 [5],
                 [4]])

Similarly, keep only the within-batch positions from b:

b = tensor([2, 2])

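and shift each position by the start of its batch, i.e. the running sum of the preceding counts:

offset_b = b + offset.reshape(-1).cumsum(0)[:-1]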
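One detail worth checking: the start of each batch in the flattened tensor is the running total of the preceding counts, not the counts themselves. With two batches these coincide ([0, 5] either way), but with more batches they differ. A small sketch with a hypothetical third batch:

```python
import torch

# Hypothetical per-batch counts for three batches, prefixed with 0.
offset = torch.tensor([[0], [5], [4], [3]])

# Start of each batch in the flattened tensor: exclusive cumulative sum.
starts = offset.reshape(-1).cumsum(dim=0)[:-1]
print(starts.tolist())                   # [0, 5, 9]

# Slicing the raw counts instead is only right for the first two batches.
print(offset.reshape(-1)[:-1].tolist())  # [0, 5, 4]

b = torch.tensor([2, 2, 1])              # hypothetical within-batch positions
offset_b = b + starts
print(offset_b.tolist())                 # [2, 7, 10]
```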

Then

indices = A.reshape(-1)[offset_b]  # tensor([3, 11])
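Putting the pieces together, here is a sketch that applies the same flatten-and-offset idea directly to the original A and b from the question. It assumes that A[:, 0] holds group ids 0..G-1, that A[:, 1] holds non-negative column ids with no duplicate rows and sorted ascending within each group (as in the example), and that every b[i, 1] is a valid position in its group. The function name `index_by_group_flat` is hypothetical.

```python
import torch

def index_by_group_flat(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical flatten-and-offset sketch. Assumes group ids 0..G-1 in
    A[:, 0], unique non-negative column ids in A[:, 1] sorted within each
    group, and valid within-group positions in b[:, 1]."""
    num_groups = int(A[:, 0].max()) + 1
    width = int(A[:, 1].max()) + 1
    # Dense mask: mask[g, c] is True when the row [g, c] appears in A.
    mask = torch.zeros(num_groups, width, dtype=torch.bool)
    mask[A[:, 0], A[:, 1]] = True
    # 1D problem: positions of the True entries in the flattened mask.
    flat = torch.nonzero(mask.reshape(-1)).reshape(-1)
    counts = mask.sum(dim=1)                      # entries per group
    starts = torch.cumsum(counts, dim=0) - counts  # exclusive cumulative sum
    picked = flat[starts[b[:, 0]] + b[:, 1]]      # flattened positions
    return torch.stack([b[:, 0], picked % width], dim=1)

A = torch.tensor([[0, 0], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6],
                  [1, 0], [1, 1], [1, 4], [1, 5], [1, 6]])
b = torch.tensor([[0, 2], [1, 2]])
result = index_by_group_flat(A, b)
print(result.tolist())  # [[0, 3], [1, 4]]
```

The dense mask costs memory proportional to G × width; that is the price of reducing the lookup to plain 1D indexing.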