How to use nn.Embedding?


(손주형) #1

I have a text tensor a of shape 64 x 4 x 3 x 30,

and I want to create word embeddings for it with nn.Embedding().
Which of these is right?

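Option 1:

b = nn.Embedding(20000, 300)
b(a)

Option 2:

b = nn.Embedding(20000, 300)
a = a.view(-1, 30)             # flatten to (64*4*3, 30)
a = b(a)                       # -> (768, 30, 300)
a = a.view(-1, 4, 3, 30, 300)  # restore leading dims, keep the 300-dim embedding axis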
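One way to check which version is right is to run both on random indices and compare the results. This is a minimal sketch, assuming a holds integer word indices below 20000; the random tensor is only a stand-in for the real text data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed stand-in for the real text tensor: random word indices in [0, 20000).
a = torch.randint(0, 20000, (64, 4, 3, 30))
b = nn.Embedding(20000, 300)

# Option 1: pass the 4-D index tensor directly.
out1 = b(a)
print(out1.shape)  # torch.Size([64, 4, 3, 30, 300])

# Option 2: flatten the leading dims, embed, then restore them.
flat = a.view(-1, 30)               # (768, 30)
out2 = b(flat).view(*a.shape, 300)  # (64, 4, 3, 30, 300)

# Both routes look up the same rows of b.weight, so the values are identical.
print(torch.equal(out1, out2))  # True
```

Option 2 only matches when the final view keeps the trailing 300 (embedding) axis; a view back to 4 dims such as (-1, 4, 3, 30) silently folds that axis into the batch dimension.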

(Arul) #2

For me, this example works:

import torch
import torch.nn as nn

embedding_dim = 5
num_words = 20

embedding = nn.Embedding(num_words, embedding_dim)
# random word indices in [0, num_words), shape 1x2x3x2
random_word_indices = torch.randint(low=0, high=num_words, size=(1, 2, 3, 2), dtype=torch.long)
word_embeddings = embedding(random_word_indices)  # shape 1x2x3x2x5

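Here I assumed the text tensor random_word_indices has size 1x2x3x2. The result word_embeddings has size 1x2x3x2x5: nn.Embedding accepts an index tensor of any shape and appends an embedding_dim axis, so you can call b(a) directly without reshaping.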
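To see what the lookup actually does, here is a small sketch reusing the sizes from the example above (redefined so it runs on its own). It checks that each output vector is just the row of embedding.weight selected by the corresponding index, so the call behaves like plain advanced indexing into the weight matrix (with the default max_norm=None).

```python
import torch
import torch.nn as nn

embedding_dim = 5
num_words = 20

embedding = nn.Embedding(num_words, embedding_dim)
indices = torch.randint(0, num_words, (1, 2, 3, 2))  # int64 indices
out = embedding(indices)

# Output shape = index shape + (embedding_dim,), here (1, 2, 3, 2, 5).
assert out.shape == (*indices.shape, embedding_dim)

# Each output vector is the weight row picked out by its index.
assert torch.equal(out, embedding.weight[indices])
```

Note that the input must be an integer tensor (LongTensor or IntTensor); passing float indices raises an error, which is why the text tensor has to hold word ids rather than raw values.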