# Indexing 2D tensor using 3D tensor

Hi,

I want to index a 2D tensor of shape [N, S] with a 3D tensor of shape [N, P, K] whose values are indices in the range [0, S-1]; the output should have shape [N, P, K]. I'm stuck and don't know how to do it. Here is a minimal example:

```python
import torch

index = torch.randint(0, 5, size = [100, 20, 15])  # shape [N, P, K], values in [0, S-1]
X = torch.randn(size = [100, 5])                    # shape [N, S] with S = 5
```

What is the point of your K dimension on your index? Doesn’t seem you’re using it.

Sorry, my bad: the output should be [N, P, K], and the indices are in the range [0, S-1].
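To pin down the intended semantics, here is a deliberately slow, loop-based reference. It assumes that "indexing" here means output[n, p, k] = X[n, index[n, p, k]] for every n, p, k; the helper name `reference_lookup` is hypothetical and only for illustration. It uses the shapes from the minimal example above.

```python
import torch

def reference_lookup(X: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: output[n, p, k] = X[n, index[n, p, k]], computed with plain loops."""
    N, P, K = index.shape
    assert X.shape[0] == N, "X and index must share the leading N dimension"
    rows = X.tolist()     # N lists of S Python floats
    idx = index.tolist()  # N x P x K nested lists of Python ints
    out = [[[rows[n][idx[n][p][k]]  # pick the column of row n named by the index value
             for k in range(K)]
            for p in range(P)]
           for n in range(N)]
    return torch.tensor(out, dtype=X.dtype)

torch.manual_seed(0)
index = torch.randint(0, 5, size=[100, 20, 15])  # [N, P, K], values in [0, S-1] with S = 5
X = torch.randn(size=[100, 5])                   # [N, S]
output = reference_lookup(X, index)
print(output.shape)  # torch.Size([100, 20, 15])
```

Any vectorized version should agree exactly with this reference, since the operation only selects existing values and does no arithmetic.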

Hi Max!

Based on my simplest interpretation of what you mean by “indexing,” I expect
you want `gather()`.
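For reference, for a 3D input `torch.gather(input, 2, index)` computes out[i][j][k] = input[i][j][index[i][j][k]], and `index` must have the same number of dimensions as `input` (and be an int64 tensor). That is why the 2D `X` first gets a P axis through `unsqueeze(1)` followed by `expand()`, which returns a broadcast view rather than a copy. A small sketch of that rule on made-up toy values:

```python
import torch

X = torch.tensor([[10.0, 20.0, 30.0],
                  [40.0, 50.0, 60.0]])     # shape [N=2, S=3]
index = torch.tensor([[[2, 0], [1, 1]],
                      [[0, 2], [2, 1]]])   # shape [N=2, P=2, K=2], int64, values in [0, 2]

N, P, K = index.shape
# expand() gives the new P axis stride 0, so the row data is shared, not copied.
X3 = X.unsqueeze(1).expand(N, P, X.shape[1])  # shape [N, P, S]
print(X3.stride())                            # (3, 0, 1)

# gather along the last axis: output[n][p][k] = X3[n][p][index[n][p][k]] = X[n][index[n][p][k]]
output = X3.gather(2, index)                  # shape [N, P, K]
print(output.tolist())  # [[[30.0, 10.0], [20.0, 20.0]], [[40.0, 60.0], [60.0, 50.0]]]
```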

Consider (with smaller tensors for convenience):

```
>>> import torch
>>> print (torch.__version__)
2.2.1
>>>
>>> _ = torch.manual_seed (2024)
>>>
>>> index = torch.randint (0, 3, size = [4, 2, 6])
>>> X = torch.randn (size = [4, 3])
>>>
>>> output = X.unsqueeze (1).expand (4, 2, 3).gather (2, index)
>>>
>>> index
tensor([[[2, 2, 1, 2, 0, 2],
         [1, 2, 2, 0, 2, 1]],

        [[0, 1, 0, 2, 0, 2],
         [0, 1, 2, 1, 0, 1]],

        [[0, 2, 2, 2, 0, 2],
         [2, 1, 0, 2, 2, 2]],

        [[1, 2, 1, 0, 0, 0],
         [0, 0, 2, 1, 0, 2]]])
>>> X
tensor([[ 0.6925, -0.6998,  0.3889],
        [ 0.5154,  0.3058, -1.3593],
        [-0.2669,  0.2686, -1.0352],
        [ 0.3725, -1.0863,  0.0871]])
>>> output
tensor([[[ 0.3889,  0.3889, -0.6998,  0.3889,  0.6925,  0.3889],
         [-0.6998,  0.3889,  0.3889,  0.6925,  0.3889, -0.6998]],

        [[ 0.5154,  0.3058,  0.5154, -1.3593,  0.5154, -1.3593],
         [ 0.5154,  0.3058, -1.3593,  0.3058,  0.5154,  0.3058]],

        [[-0.2669, -1.0352, -1.0352, -1.0352, -0.2669, -1.0352],
         [-1.0352,  0.2686, -0.2669, -1.0352, -1.0352, -1.0352]],

        [[-1.0863,  0.0871, -1.0863,  0.3725,  0.3725,  0.3725],
         [ 0.3725,  0.3725,  0.0871, -1.0863,  0.3725,  0.0871]]])
```

Best.

K. Frank