How does the nn.Embedding module relate intuitively to the idea of an embedding in general?

So, I’m having a hard time understanding nn.Embedding. Specifically, I can’t connect the dots between what I understand about embeddings as a concept and what this specific implementation is doing.

My understanding of an embedding is that it is a lower-dimensional representation of some higher-dimensional data point. So it maps data in N-d to an M-d latent/embedding space such that M < N.

As I understand it, this mapping is achieved through the learning process, as in an auto-encoder. The encoder learns the optimal embedding so that the decoder can reconstruct the original input.

So my question is: how does this relate to the nn.Embedding module, which the docs describe as:

A simple lookup table that stores embeddings of a fixed dictionary and size.
This module is often used to store word embeddings and retrieve them using indices. The input to the module is a list of indices, and the output is the corresponding word embeddings.

Does this layer “learn” a lower-dimensional representation of a larger input space? Or is it something else entirely?

In the general case, an embedding is, indeed, a mapping from an N-d space to an M-d space. The simplest form of embedding is therefore a matrix A of shape (M, N) (but there are also non-linear embeddings), such that the transformation is y = A * x.

However, in the field of Neural Machine Translation (NMT), the words/tokens are not represented by dense vectors like x = [0.556, 0.2465, 0.015, ...], but rather by one-hot encodings, like x = [0, 0, 0, 1, 0, 0, 0, ...]. Hence, they are uniquely identified by the index of the 1 in x.

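Therefore, instead of computing the product A * x, you can simply look up the column of A at the index i of the 1, i.e. y = A[:, i], which is much more efficient. (nn.Embedding stores the transposed matrix, with shape (N, M) = (num_embeddings, embedding_dim), so in practice it looks up a row.)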

That is exactly what nn.Embedding implements: it is a trainable matrix (its entries are learned by backpropagation along with the rest of the model) which is used as a lookup table for the word indices.
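A minimal sketch (assuming PyTorch; the vocabulary size, embedding size, and token indices are arbitrary illustrative values) showing that indexing nn.Embedding gives the same result as multiplying one-hot vectors by its weight matrix, and that the table really is learned:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size, embedding_dim = 10, 4  # N = 10 possible tokens, M = 4 dims (arbitrary sizes)
emb = nn.Embedding(vocab_size, embedding_dim)  # emb.weight has shape (N, M)

indices = torch.tensor([3, 7, 3])  # a toy "sentence" of three token indices

# Route 1: the lookup table, i.e. select rows of emb.weight by index.
looked_up = emb(indices)  # shape (3, M)

# Route 2: one-hot encode each index, then do a full matrix product.
one_hot = F.one_hot(indices, num_classes=vocab_size).float()  # shape (3, N)
multiplied = one_hot @ emb.weight  # shape (3, M)

print(torch.allclose(looked_up, multiplied))  # True

# The table is an ordinary trainable parameter: gradients flow into it,
# and only the rows that were looked up receive a non-zero gradient.
looked_up.sum().backward()
print(emb.weight.grad[3])  # row 3 was used twice, so its gradient is all 2s
print(emb.weight.grad[0])  # row 0 was never used, so its gradient is all 0s
```

The lookup skips the N-wide multiplication entirely, which is why it is preferred when the inputs are indices.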

Okay, that makes a little sense to me. Let me see if I can put what you said in my own words, and you tell me if I understood correctly:

In the linear sense, an embedding is just PCA…

Then there are two cases… Let’s say a sequence of numerical data (like stock prices) and a sequence of characters in a sentence.

In the first case y = Ax is easy. But with the string of characters you have to one-hot-encode each character, so then you get a sequence of sparse vectors. This isn’t so easy to transform to a lower dimension because it’s not a simple column vector. So nn.Embedding is the solution to this problem?

So in my personal use case, I want to do anomaly detection through reconstruction error. I’m trying to train an autoencoder to reconstruct timeseries data. This, I think, falls into the first category where the sequence is a dense vector.

How does that affect my use of nn.Embedding?

That is pretty much it! One nuance: the one-hot encoding is more likely to be “word-level” than “character-level”, though I might be wrong.

Anyway, in the case of stock prices, you indeed have a sequence of dense vectors already (I guess). You cannot represent each vector by an index and, therefore, you need a y = Ax transformation, as you said.

But nn.Embedding doesn’t do that; nn.Linear does! So just go ahead and use that ;)
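For dense inputs like these, a minimal sketch of the y = Ax transformation with nn.Linear (assuming PyTorch; the batch of 5 days, the 3 input features, and the 2-d output are made-up sizes). Note that nn.Linear also adds a bias b unless bias=False is passed:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder batch: 5 days, each already a dense vector of 3 features.
# No indices involved, so there is nothing for nn.Embedding to look up.
x = torch.randn(5, 3)

# nn.Linear stores A as weight with shape (out_features, in_features) = (2, 3).
encoder = nn.Linear(in_features=3, out_features=2)
y = encoder(x)
print(y.shape)  # torch.Size([5, 2])

# The same computation written out by hand: y = x A^T + b (row-vector convention).
manual = x @ encoder.weight.T + encoder.bias
print(torch.allclose(y, manual, atol=1e-6))  # True
```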

Ah okay! To be clear, let’s say I have daily samples:

['hi temp', 'inches rain', 'low temp']

Originally, I just made an autoencoder using nn.Linear along with ReLU activations, and reconstructed each day’s sample (all three variables). The anomaly detection was pretty poor with this approach.
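A rough sketch of the per-day approach described above, scoring each day by its reconstruction error (assuming PyTorch; the class name DailyAutoencoder, the layer widths, latent size, training settings, and the random placeholder data are all hypothetical, not the poster's actual model):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class DailyAutoencoder(nn.Module):
    """Hypothetical per-day autoencoder: 3 features -> 2-d code -> 3 features."""

    def __init__(self, n_features=3, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(n_features, 8), nn.ReLU(), nn.Linear(8, latent_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 8), nn.ReLU(), nn.Linear(8, n_features)
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

model = DailyAutoencoder()
days = torch.randn(100, 3)  # placeholder data, one row per day (not real measurements)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
for _ in range(200):
    optimizer.zero_grad()
    loss = loss_fn(model(days), days)  # learn to reproduce each day's own vector
    loss.backward()
    optimizer.step()

# Per-day reconstruction error; unusually large values flag candidate anomalies.
with torch.no_grad():
    errors = ((model(days) - days) ** 2).mean(dim=1)
print(errors.shape)  # torch.Size([100])
```

Each day is treated independently here, so the model never sees how the readings evolve over time, which is one plausible reason to move to a sequence model.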

So instead, I’m trying to think about it as three timeseries/sequences basically… I’m trying to learn how to reconstruct sequences using LSTMs instead, following this article:

So in this case, could I use nn.Embedding as a way to have a dictionary of 3 separate embeddings, one for each of the 3 time series?

Thank you for your help!

I’m not sure I understand what you mean by “reconstructing sequences”. Do you mean detecting AND correcting the anomalies? Anyway, I’m not really qualified for that, sorry :(

But, in the blog post you referenced, the author uses nn.Embedding because he works with words (language translation).
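A small sketch of why nn.Embedding does not fit the weather data (assuming PyTorch; the three series ids and the embedding size are hypothetical). It only accepts integer ids, and the vector it returns for an id is the same whatever that day's readings were, so it cannot encode the measurements themselves:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One row per series id (0 = hi temp, 1 = inches rain, 2 = low temp), hypothetical.
emb = nn.Embedding(num_embeddings=3, embedding_dim=4)

series_ids = torch.tensor([0, 1, 2])  # integer ids: valid input
vectors = emb(series_ids)  # one fixed vector per id, independent of any measured value
print(vectors.shape)  # torch.Size([3, 4])

readings = torch.randn(1, 3)  # placeholder float readings for one day
rejected = False
try:
    emb(readings)  # float values are not valid indices
except RuntimeError:
    rejected = True
print(rejected)  # True
```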

Your case is more similar to this tutorial

Btw, you will note that this one doesn’t use nn.Embedding (or nn.Linear) but rather feeds the vector sequences directly into the LSTM(s).
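A minimal sketch of that idea for the three weather series (assuming PyTorch; the class name LSTMAutoencoder, the hidden size, the window length, and the repeat-the-code decoding scheme are illustrative choices, not taken from the tutorial). Each time step's dense 3-feature vector goes straight into the encoder LSTM; a linear head is used only at the output to map hidden states back to feature values:

```python
import torch
import torch.nn as nn

class LSTMAutoencoder(nn.Module):
    """Hypothetical sequence autoencoder for (batch, seq_len, n_features) input."""

    def __init__(self, n_features=3, hidden_size=16):
        super().__init__()
        # Raw feature vectors are the LSTM input: no nn.Embedding in front.
        self.encoder = nn.LSTM(n_features, hidden_size, batch_first=True)
        self.decoder = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.output = nn.Linear(hidden_size, n_features)  # hidden state -> features

    def forward(self, x):
        _, (h, _) = self.encoder(x)  # h: (1, batch, hidden_size), summary of the window
        seq_len = x.size(1)
        # Repeat the summary at every time step and decode it back into a sequence.
        z = h[-1].unsqueeze(1).repeat(1, seq_len, 1)  # (batch, seq_len, hidden_size)
        out, _ = self.decoder(z)
        return self.output(out)  # (batch, seq_len, n_features)

torch.manual_seed(0)
model = LSTMAutoencoder()
windows = torch.randn(4, 30, 3)  # placeholder: 4 windows of 30 days x 3 features
recon = model(windows)
error = ((recon - windows) ** 2).mean(dim=(1, 2))  # one reconstruction error per window
print(recon.shape)  # torch.Size([4, 30, 3])
```

After training on normal windows (with the same MSE loop as for a per-day model), windows whose error is unusually high would be the anomaly candidates.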

This is a perfect example for what I’m trying to do. Thank you!