Indexing 2D tensor using 3D tensor

Hi,

I want to index a 2D tensor of shape [N, S] with a 3D tensor of shape [N, P, K] whose entries are indices in the range [0, S-1]; the output should have shape [N, P, K]. I’m stuck and don’t know how to do it. Here is a minimal example:

index = torch.randint(0, 5, size = [100, 20, 15])
X = torch.randn(size = [100, 5])

What is the point of the K dimension of your index? It doesn’t seem like you’re using it.

Sorry, my bad: the output should be [N, P, K], and the indices are in the range [0, S-1].
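To pin down the intended semantics, here is a slow but explicit reference, sketched with the sizes from the minimal example above: for every n, p, k, output[n, p, k] = X[n, index[n, p, k]]. The per-row loop and the name `reference` are illustrative only; this is mainly useful for checking a vectorized solution against.

```python
import torch

# Sizes from the minimal example in the question: N=100, S=5, P=20, K=15
N, S, P, K = 100, 5, 20, 15
index = torch.randint(0, S, size=[N, P, K])
X = torch.randn(size=[N, S])

# For each row n, index the 1D row X[n] (length S) with the [P, K] index tensor.
# This yields a [P, K] tensor with entries X[n, index[n, p, k]].
reference = torch.stack([X[n][index[n]] for n in range(N)])

print(reference.shape)  # torch.Size([100, 20, 15])
```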

Hi Max!

Based on my simplest interpretation of what you mean by “indexing,” I expect
you want gather().
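For reference, torch.gather(input, dim, index) requires input and index to have the same number of dimensions, with index.size(d) <= input.size(d) for every d != dim; the output takes the shape of index. That is why the [N, S] tensor gets a singleton P axis and an expand() before gathering along the last axis (K may exceed S there, since that is the gather axis). A sketch with the question’s sizes, showing the shapes and the zero stride of the expanded view:

```python
import torch

N, S, P, K = 100, 5, 20, 15
index = torch.randint(0, S, size=[N, P, K])
X = torch.randn(size=[N, S])

# gather() needs input and index to have the same number of dimensions,
# so give X a singleton P axis: [N, S] -> [N, 1, S] -> [N, P, S].
# expand() returns a view with stride 0 along the new axis (no copy is made).
X_expanded = X.unsqueeze(1).expand(N, P, S)
print(X_expanded.shape, X_expanded.stride())  # torch.Size([100, 20, 5]) (5, 0, 1)

# Gather along the last (S) axis: output[n, p, k] = X_expanded[n, p, index[n, p, k]]
output = X_expanded.gather(2, index)
print(output.shape)  # torch.Size([100, 20, 15])
```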

Consider (with smaller tensors for convenience):

>>> import torch
>>> print (torch.__version__)
2.2.1
>>>
>>> _ = torch.manual_seed (2024)
>>>
>>> index = torch.randint (0, 3, size = [4, 2, 6])
>>> X = torch.randn (size = [4, 3])
>>>
>>> output = X.unsqueeze (1).expand (4, 2, 3).gather (2, index)
>>>
>>> index
tensor([[[2, 2, 1, 2, 0, 2],
         [1, 2, 2, 0, 2, 1]],

        [[0, 1, 0, 2, 0, 2],
         [0, 1, 2, 1, 0, 1]],

        [[0, 2, 2, 2, 0, 2],
         [2, 1, 0, 2, 2, 2]],

        [[1, 2, 1, 0, 0, 0],
         [0, 0, 2, 1, 0, 2]]])
>>> X
tensor([[ 0.6925, -0.6998,  0.3889],
        [ 0.5154,  0.3058, -1.3593],
        [-0.2669,  0.2686, -1.0352],
        [ 0.3725, -1.0863,  0.0871]])
>>> output
tensor([[[ 0.3889,  0.3889, -0.6998,  0.3889,  0.6925,  0.3889],
         [-0.6998,  0.3889,  0.3889,  0.6925,  0.3889, -0.6998]],

        [[ 0.5154,  0.3058,  0.5154, -1.3593,  0.5154, -1.3593],
         [ 0.5154,  0.3058, -1.3593,  0.3058,  0.5154,  0.3058]],

        [[-0.2669, -1.0352, -1.0352, -1.0352, -0.2669, -1.0352],
         [-1.0352,  0.2686, -0.2669, -1.0352, -1.0352, -1.0352]],

        [[-1.0863,  0.0871, -1.0863,  0.3725,  0.3725,  0.3725],
         [ 0.3725,  0.3725,  0.0871, -1.0863,  0.3725,  0.0871]]])
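The same result can be obtained in at least two other ways; this is a sketch, not part of the answer above. One flattens the (P, K) axes of the index so that gather() works on X directly, and the other uses advanced indexing with a row index built from torch.arange that broadcasts against index. Which one is fastest will likely depend on sizes and device, so the choice here is untested:

```python
import torch

_ = torch.manual_seed(2024)
N, S, P, K = 4, 3, 2, 6
index = torch.randint(0, S, size=[N, P, K])
X = torch.randn(size=[N, S])

# 1) gather() on an expanded view, as in the session above
out_gather = X.unsqueeze(1).expand(N, P, S).gather(2, index)

# 2) gather() on X directly, after flattening the (P, K) axes of the index
out_flat = X.gather(1, index.reshape(N, P * K)).reshape(N, P, K)

# 3) advanced indexing: a [N, 1, 1] row index broadcasts against index [N, P, K]
rows = torch.arange(N).view(N, 1, 1)
out_adv = X[rows, index]

print(torch.equal(out_gather, out_flat), torch.equal(out_gather, out_adv))  # True True
```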

Best.

K. Frank