Transformers most often take as input the sum of a token embedding and a position embedding.
For example, positions 1 to 128 are represented with torch.nn.Embedding(num_embeddings=128, ...).
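For concreteness, here is a minimal sketch of the pattern I mean (the sizes and names like d_model are just placeholders I picked):

```python
import torch
import torch.nn as nn

vocab_size, max_len, d_model = 1000, 128, 64

tok_emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
pos_emb = nn.Embedding(num_embeddings=max_len, embedding_dim=d_model)

tokens = torch.randint(0, vocab_size, (2, max_len))   # (batch, seq) of integer token ids
positions = torch.arange(max_len).unsqueeze(0)        # (1, seq), broadcasts over the batch

x = tok_emb(tokens) + pos_emb(positions)              # (batch, seq, d_model) transformer input
```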
I never see torch.nn.Linear used to project a float position to an embedding, nor do I see the sparse flag set on the embedding.
If the two (Linear and Embedding) are essentially the same, I would expect at least some people to choose the linear projection (cleaner in my mind when the embedding is for position).
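As far as I can tell, an Embedding lookup gives the same result as a bias-free Linear applied to one-hot position vectors; this little check (sizes arbitrary) is how I convinced myself:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

max_len, d_model = 128, 64

emb = nn.Embedding(max_len, d_model)
lin = nn.Linear(max_len, d_model, bias=False)
with torch.no_grad():
    lin.weight.copy_(emb.weight.t())   # same parameters, transposed for Linear's (out, in) layout

pos = torch.arange(max_len)
one_hot = F.one_hot(pos, num_classes=max_len).float()

print(torch.allclose(emb(pos), lin(one_hot)))   # True: index lookup == one-hot matmul
```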
Outside of AI and backpropagation, a table lookup can be implemented much more efficiently than multiplying an array by a one-hot mask, especially when the table is large.
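A quick, unscientific forward-only timing of that point, with sizes I made up:

```python
import time
import torch
import torch.nn.functional as F

table = torch.randn(10_000, 768)                  # large lookup table
idx = torch.randint(0, 10_000, (1024,))
one_hot = F.one_hot(idx, num_classes=10_000).float()

t0 = time.perf_counter()
for _ in range(10):
    _ = table[idx]                                # lookup: gathers 1024 rows
print("lookup:        ", time.perf_counter() - t0)

t0 = time.perf_counter()
for _ in range(10):
    _ = one_hot @ table                           # matmul: touches all 10,000 rows
print("one-hot matmul:", time.perf_counter() - t0)
```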
I believe BERT-style transformers use very large embeddings (a 52K vocabulary) to represent words, in addition to the embeddings for word position.
I scavenged the PyTorch GitHub repo and found Embedding.cpp in the call path of nn.Embedding. I have no idea how this code does its magic, but embedding_dense_backward_cpu has a bunch of if statements before adding into grad_weights, while Linear.cpp does a matrix multiplication.
So I’m guessing Embedding is much faster than Linear in backpropagation, especially when large embeddings are used; with small embeddings they should be essentially the same.
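A rough timing of forward + backward for the two modules (sizes arbitrary, numbers will vary by machine) should show whether my guess holds:

```python
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

num_embeddings, dim, batch = 10_000, 768, 1024
idx = torch.randint(0, num_embeddings, (batch,))
one_hot = F.one_hot(idx, num_classes=num_embeddings).float()
grad_out = torch.randn(batch, dim)

emb = nn.Embedding(num_embeddings, dim)
lin = nn.Linear(num_embeddings, dim, bias=False)

def fwd_bwd_time(forward):
    t0 = time.perf_counter()
    for _ in range(10):
        forward().backward(grad_out)   # gradients accumulate; only the timing matters here
    return time.perf_counter() - t0

print("Embedding fwd+bwd:", fwd_bwd_time(lambda: emb(idx)))
print("Linear    fwd+bwd:", fwd_bwd_time(lambda: lin(one_hot)))
```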
Hoping someone who understands the PyTorch implementation can say for sure.