Lookup table in PyTorch

Hi

I am new to PyTorch and was wondering whether the following can be done.
If I have a 2D tensor of indices, for example:

tensor([[1, 2, 4],
        [4, 2, 0],
        [0, 1, 3]])

and a 1D tensor of values:

tensor([6, 8, 9, 11, 23])

I would like to replace each index with the value at that position, giving:

tensor([[ 8,  9, 23],
        [23,  9,  6],
        [ 6,  8, 11]])
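To pin down the operation being asked about, here is a minimal sketch of the lookup written as an explicit double loop, where each output element is out[i, j] = values[idx[i, j]]. The function name `lookup_loop` is hypothetical and the loop is only meant to spell out the semantics; it is far slower than a vectorized approach.

```python
import torch

# Hypothetical reference implementation: spells out out[i, j] = values[idx[i, j]]
# with plain Python loops. Illustration only; slow for large tensors.
def lookup_loop(values: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    out = torch.empty(idx.shape, dtype=values.dtype, device=values.device)
    for i in range(idx.shape[0]):
        for j in range(idx.shape[1]):
            out[i, j] = values[idx[i, j]]  # look up one index at a time
    return out

values = torch.tensor([6, 8, 9, 11, 23])
idx = torch.tensor([[1, 2, 4],
                    [4, 2, 0],
                    [0, 1, 3]])
print(lookup_loop(values, idx))
```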

Thanks

Yes, you can index the values tensor directly with the index tensor. The result has the same shape as the index tensor:

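import torch

x = torch.tensor([6, 8, 9, 11, 23])      # values (the lookup table)
idx = torch.tensor([[1, 2, 4],
                    [4, 2, 0],
                    [0, 1, 3]])          # indices into x

ret = x[idx]  # advanced indexing: ret[i, j] = x[idx[i, j]]
print(ret)
# tensor([[ 8,  9, 23],
#         [23,  9,  6],
#         [ 6,  8, 11]])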
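For comparison, the same lookup can be sketched with two other built-in functions. `torch.take` treats its input as if it were flattened to 1D and returns a tensor shaped like the index, which matches `x[idx]` here because `x` is already 1D. `index_select` only accepts a 1D index, so the index is flattened first and the result reshaped back. Plain `x[idx]` is usually the simplest choice; these are shown only as equivalent alternatives.

```python
import torch

x = torch.tensor([6, 8, 9, 11, 23])
idx = torch.tensor([[1, 2, 4],
                    [4, 2, 0],
                    [0, 1, 3]])

# torch.take indexes into x as if it were 1D; output has idx's shape.
out_take = torch.take(x, idx)

# index_select needs a 1D index: flatten, select along dim 0, reshape back.
out_select = x.index_select(0, idx.flatten()).reshape(idx.shape)

print(torch.equal(out_take, x[idx]), torch.equal(out_select, x[idx]))
# True True
```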
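The same indexing also works when the table holds a row of features per entry instead of a single number: each index then pulls out a whole row, and the result has shape `idx.shape + (dim,)`. The sketch below assumes a hypothetical 5×2 float table and checks that `table[idx]` agrees with `torch.nn.functional.embedding`, which performs this kind of row lookup.

```python
import torch
import torch.nn.functional as F

# Hypothetical 2D table: 5 entries, each a row of 2 float features.
table = torch.arange(10, dtype=torch.float32).reshape(5, 2)
idx = torch.tensor([[1, 2, 4],
                    [4, 2, 0],
                    [0, 1, 3]])

rows = table[idx]                   # one row of `table` per index
rows_emb = F.embedding(idx, table)  # same row lookup via the embedding function

print(rows.shape)                   # torch.Size([3, 3, 2])
print(torch.equal(rows, rows_emb))  # True
```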