How does nn.Embedding work?

An Embedding layer is essentially a Linear layer without a bias. So you could define your layer as nn.Linear(1000, 30, bias=False) and represent each word as a one-hot vector, e.g., [0,0,1,0,...,0] (the length of the vector is 1,000).

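As you can see, every word is a unique vector of size 1,000, with a 1 in a position that no other word uses. Passing such a vector v with v[2]=1 (cf. the example vector above) through the Linear layer simply returns the 30 weights attached to input index 2 (plus the bias, unless the layer was created with bias=False). Note that nn.Linear(1000, 30) stores its weight as a 30 x 1000 matrix, so this is column 2 of linear.weight; in nn.Embedding(1000, 30) the weight is 1000 x 30 and the same vector is row 2.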
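
To make the one-hot argument concrete, here is a minimal sketch. It assumes a bias-free Linear layer and uses F.one_hot to build the vector; the variable names are only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Bias-free linear layer: weight has shape (30, 1000)
linear = nn.Linear(1000, 30, bias=False)

# One-hot vector of length 1000 with a single 1 at index 2
one_hot = F.one_hot(torch.tensor(2), num_classes=1000).float()

# Multiplying by a one-hot vector just picks out one slice of the weight
out = linear(one_hot)            # shape (30,)
selected = linear.weight[:, 2]   # column 2 of the (30, 1000) weight

matches = torch.allclose(out, selected)
```

Every other entry of the one-hot vector is 0, so the matrix product reduces to copying out the 30 weights for index 2.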

nn.Embedding just simplifies this. Instead of giving it a big one-hot vector, you just give it an index. This index is simply the position of the single 1 in the one-hot vector.
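
A quick way to check the equivalence is the sketch below. It assumes we copy the embedding weights (transposed) into a bias-free Linear layer; the indices 2, 7, and 999 are arbitrary examples.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

embedding = nn.Embedding(1000, 30)        # weight shape (1000, 30)
linear = nn.Linear(1000, 30, bias=False)  # weight shape (30, 1000)

# Give both layers the same parameters (Linear stores the transpose)
with torch.no_grad():
    linear.weight.copy_(embedding.weight.T)

idx = torch.tensor([2, 7, 999])           # word indices

# Lookup by index vs. multiplication by one-hot vectors
by_index = embedding(idx)                                       # (3, 30)
by_one_hot = linear(F.one_hot(idx, num_classes=1000).float())   # (3, 30)

same = torch.allclose(by_index, by_one_hot)
```

Both paths produce the same vectors, but the index lookup never has to build the 1,000-element one-hot vectors or multiply by all those zeros.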
