What is the PyTorch equivalent of TensorFlow's embedding_lookup() function (for 2D indices)?

  • Operating System: Windows 10
  • Python Version: 3.7.11
  • PyTorch Version: 1.10.1

I have the following two tensors:

import torch

embedding_vectors = torch.tensor([
    [0.01, 0.02, 0.03], 
    [0.07, 0.08, 0.04], 
    [0.05, 0.09, 0.06],
    [0.51, 0.92, 0.67],
    [0.55, 0.99, 0.64],
    [0.17, 0.23, 0.85],
    [0.45, 0.66, 0.31],
    [0.01, 0.07, 0.92],
    [0.25, 0.56, 0.32]
])

indices = torch.tensor([
        [0, 2], 
        [4, 5], 
        [6, 0]
])

I want to map each value in indices to the corresponding row of embedding_vectors, so I expect the tensor below as output:

[
    [[0.01, 0.02, 0.03], [0.05, 0.09, 0.06]], 
    [[0.55, 0.99, 0.64], [0.17, 0.23, 0.85]], 
    [[0.45, 0.66, 0.31], [0.01, 0.02, 0.03]]
]

Questions:

  • Does PyTorch have a built-in function that does this, the same way tf.nn.embedding_lookup(embedding_vectors, indices) does in TensorFlow?
  • If not, how can I do this?

I tried torch.index_select(embedding_vectors, 0, indices), but it raises an error saying that it expects a vector (1D tensor) as the index, while my indices variable has 2 dimensions.

You could use the functional API (torch.nn.functional, imported here as F) via:

import torch.nn.functional as F

F.embedding(indices, embedding_vectors)

or assign embedding_vectors to the .weight attribute of an nn.Embedding layer.
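
For reference, here is a minimal sketch of both suggestions using the tensors from the question. It also shows plain tensor indexing and a flatten-then-reshape variant of torch.index_select, which should give the same result here. The output variable names are only illustrative, and nn.Embedding.from_pretrained is used as one way of putting embedding_vectors into the layer's .weight; treat this as a sketch rather than the only way to do it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

embedding_vectors = torch.tensor([
    [0.01, 0.02, 0.03],
    [0.07, 0.08, 0.04],
    [0.05, 0.09, 0.06],
    [0.51, 0.92, 0.67],
    [0.55, 0.99, 0.64],
    [0.17, 0.23, 0.85],
    [0.45, 0.66, 0.31],
    [0.01, 0.07, 0.92],
    [0.25, 0.56, 0.32],
])

indices = torch.tensor([
    [0, 2],
    [4, 5],
    [6, 0],
])

# Option 1: functional API. Note the argument order: indices first, weights second.
out_functional = F.embedding(indices, embedding_vectors)

# Option 2: an nn.Embedding layer whose weight is the existing matrix.
# freeze=True keeps the weights out of gradient updates (an assumption about intent).
layer = nn.Embedding.from_pretrained(embedding_vectors, freeze=True)
out_layer = layer(indices)

# Option 3: plain advanced indexing with an integer tensor.
out_indexing = embedding_vectors[indices]

# Option 4: index_select only accepts a 1D index, so flatten first, then reshape back.
out_index_select = torch.index_select(
    embedding_vectors, 0, indices.flatten()
).reshape(*indices.shape, -1)

print(out_functional.shape)  # torch.Size([3, 2, 3])
```

All four results have shape (3, 2, 3), i.e. the shape of indices followed by the embedding dimension. If the embeddings need to be trained, the nn.Embedding layer (with freeze=False) is probably the more convenient choice.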
