Slice from 2D tensor using 2D indices

I have a 2D tensor A of size (n, d), where each row has d elements,
and another 2D tensor mask of size (n, 10), where each row holds 10 integers in the range [0, d-1].

I wish to use the (n, 10) indices in mask to extract elements from the corresponding
rows of A, obtaining an (n, 10) tensor of elements from A.

Is there a simple way to do this?

Hi Torchman!

If I understand your use case, yes, you may use gather().
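With dim = 1, gather() computes out[i][j] = A[i][mask[i][j]]. Here is a minimal sketch that checks this against an explicit Python loop. The helper gather_rows_loop() is hypothetical, written only to spell out the semantics (it is far slower than gather()), and mask is hard-coded instead of random so the result is reproducible.

```python
import torch

def gather_rows_loop(A, mask):
    # Hypothetical reference helper: out[i][j] = A[i][mask[i][j]]
    n, k = mask.shape
    out = torch.empty((n, k), dtype=A.dtype)
    for i in range(n):
        for j in range(k):
            out[i, j] = A[i, mask[i, j]]
    return out

n, d = 3, 8
A = torch.arange(n * d).reshape(n, d) + 1
# Fixed column indices in [0, d-1], one row of indices per row of A
mask = torch.tensor([[4, 5, 1, 0],
                     [5, 6, 6, 4],
                     [7, 6, 3, 5]])

out = A.gather(dim=1, index=mask)
print(out)
# tensor([[ 5,  6,  2,  1],
#         [14, 15, 15, 13],
#         [24, 23, 20, 22]])
assert torch.equal(out, gather_rows_loop(A, mask))
```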

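As a more general alternative, you may also use PyTorch tensor indexing, which lets you pair
each column index with an arbitrary row index, not just the row in which it sits.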

Here is an illustration of these two approaches:

>>> import torch
>>> torch.__version__
'1.12.0'
>>>
>>> _ = torch.manual_seed (2021)
>>>
>>> n = 3
>>> d = 8
>>> k = 4
>>>
>>> A = torch.arange (n * d).reshape (n, d) + 1
>>> A
tensor([[ 1,  2,  3,  4,  5,  6,  7,  8],
        [ 9, 10, 11, 12, 13, 14, 15, 16],
        [17, 18, 19, 20, 21, 22, 23, 24]])
>>>
>>> mask = torch.randint (d, (n, k))
>>> mask
tensor([[4, 5, 1, 0],
        [5, 6, 6, 4],
        [7, 6, 3, 5]])
>>>
>>> A.gather (dim = 1, index = mask)
tensor([[ 5,  6,  2,  1],
        [14, 15, 15, 13],
        [24, 23, 20, 22]])
>>>
>>> ind0 = torch.arange (n).unsqueeze (-1).expand (n, k)
>>> ind0
tensor([[0, 0, 0, 0],
        [1, 1, 1, 1],
        [2, 2, 2, 2]])
>>>
>>> A[ind0, mask]
tensor([[ 5,  6,  2,  1],
        [14, 15, 15, 13],
        [24, 23, 20, 22]])
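Two variations on the same computation, offered as a sketch (worth checking against your own PyTorch version). Advanced indexing broadcasts its index tensors, so the row index can stay of shape (n, 1) and the expand() step can be skipped. Also, torch.take_along_dim(), available since PyTorch 1.9, performs the same selection along dim 1 as gather() does here. The mask values are hard-coded for reproducibility.

```python
import torch

n, d = 3, 8
A = torch.arange(n * d).reshape(n, d) + 1
mask = torch.tensor([[4, 5, 1, 0],
                     [5, 6, 6, 4],
                     [7, 6, 3, 5]])

# Row index of shape (n, 1); it broadcasts against mask's (n, k) shape,
# so no explicit expand() is needed.
rows = torch.arange(n).unsqueeze(-1)
via_broadcast = A[rows, mask]

# take_along_dim() selects along dim 1 just as gather() does here.
via_take = torch.take_along_dim(A, mask, dim=1)

via_gather = A.gather(dim=1, index=mask)
assert torch.equal(via_broadcast, via_gather)
assert torch.equal(via_take, via_gather)
print(via_broadcast.shape)  # torch.Size([3, 4])
```

Since expand() already returns a view without copying, the broadcast form mainly saves a line rather than memory.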

Best.

K. Frank
