Hi Max!
Maxpim:
I want to index a 2D tensor [N, S] with a 3D tensor [N, P, K] whose indices are in the range [0, S-1]; the output should be [N, P, K].
…
index = torch.randint(0, 5, size = [100, 20, 15])
X = torch.randn(size = [100, 5])
Based on my simplest interpretation of what you mean by “indexing,” I expect
you want gather().
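For reference, the operation being asked for is output[n, p, k] = X[n, index[n, p, k]]. Here is a minimal, deliberately slow loop-based sketch of that definition. The function name `gather_reference` is hypothetical and only for illustration. It is handy for checking a vectorized version on small inputs.

```python
import torch

def gather_reference(X, index):
    # Hypothetical helper: explicit loops implementing
    # output[n, p, k] = X[n, index[n, p, k]]
    N, P, K = index.shape
    output = torch.empty(N, P, K, dtype=X.dtype, device=X.device)
    for n in range(N):
        for p in range(P):
            for k in range(K):
                output[n, p, k] = X[n, index[n, p, k]]
    return output

torch.manual_seed(2024)
index = torch.randint(0, 3, size=[4, 2, 6])
X = torch.randn(size=[4, 3])

# gather() along dim 2 of an expanded [N, P, S] view of X should agree exactly
vectorized = X.unsqueeze(1).expand(4, 2, 3).gather(2, index)
print(torch.equal(gather_reference(X, index), vectorized))  # True
```

The loops copy the same float values that gather() copies, so `torch.equal()` (exact comparison) is appropriate here rather than `torch.allclose()`.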
Consider (with smaller tensors for convenience):
>>> import torch
>>> print (torch.__version__)
2.2.1
>>>
>>> _ = torch.manual_seed (2024)
>>>
>>> index = torch.randint (0, 3, size = [4, 2, 6])
>>> X = torch.randn (size = [4, 3])
>>>
>>> output = X.unsqueeze (1).expand (4, 2, 3).gather (2, index)
>>>
>>> index
tensor([[[2, 2, 1, 2, 0, 2],
         [1, 2, 2, 0, 2, 1]],

        [[0, 1, 0, 2, 0, 2],
         [0, 1, 2, 1, 0, 1]],

        [[0, 2, 2, 2, 0, 2],
         [2, 1, 0, 2, 2, 2]],

        [[1, 2, 1, 0, 0, 0],
         [0, 0, 2, 1, 0, 2]]])
>>> X
tensor([[ 0.6925, -0.6998,  0.3889],
        [ 0.5154,  0.3058, -1.3593],
        [-0.2669,  0.2686, -1.0352],
        [ 0.3725, -1.0863,  0.0871]])
>>> output
tensor([[[ 0.3889,  0.3889, -0.6998,  0.3889,  0.6925,  0.3889],
         [-0.6998,  0.3889,  0.3889,  0.6925,  0.3889, -0.6998]],

        [[ 0.5154,  0.3058,  0.5154, -1.3593,  0.5154, -1.3593],
         [ 0.5154,  0.3058, -1.3593,  0.3058,  0.5154,  0.3058]],

        [[-0.2669, -1.0352, -1.0352, -1.0352, -0.2669, -1.0352],
         [-1.0352,  0.2686, -0.2669, -1.0352, -1.0352, -1.0352]],

        [[-1.0863,  0.0871, -1.0863,  0.3725,  0.3725,  0.3725],
         [ 0.3725,  0.3725,  0.0871, -1.0863,  0.3725,  0.0871]]])
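As a sketch at the sizes from the original question ([100, 5] and [100, 20, 15]), here are three variants that should all produce the same [N, P, K] result. The first is the gather() approach with `expand(-1, P, -1)` so the sizes are not hard-coded. The second flattens the trailing index dimensions and gathers along dim 1. The third uses plain advanced indexing with a broadcast batch index built from `torch.arange(N)`. The variable names are only illustrative. Whether any of these is faster on your hardware is something to measure rather than assume.

```python
import torch

torch.manual_seed(2024)
N, S, P, K = 100, 5, 20, 15
index = torch.randint(0, S, size=[N, P, K])  # entries in [0, S-1]
X = torch.randn(size=[N, S])

# 1) gather() on a [N, P, S] view of X; expand() does not copy X's data
out_gather = X.unsqueeze(1).expand(-1, P, -1).gather(2, index)

# 2) flatten index to [N, P*K], gather along dim 1, then restore [N, P, K]
out_flat = X.gather(1, index.reshape(N, -1)).reshape(N, P, K)

# 3) advanced indexing: batch index of shape [N, 1, 1] broadcasts against [N, P, K]
batch = torch.arange(N).view(N, 1, 1)
out_adv = X[batch, index]

print(out_gather.shape)  # torch.Size([100, 20, 15])
print(torch.equal(out_gather, out_flat), torch.equal(out_gather, out_adv))  # True True
```

All three rely on the entries of `index` lying in [0, S-1]. On the CPU an out-of-range entry raises an error rather than silently wrapping around.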
Best.
K. Frank