How does nn.Embedding work?

Does PyTorch treat backpropagation through a linear layer fed a one-hot input the same as backpropagation through an embedding lookup by index?

I'm guessing that PyTorch will calculate the gradient for the entire weight matrix of the linear layer, and all but one column will be zero given the one-hot input (i.e., lots of computation for a large linear layer). Or does PyTorch optimize this away?
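Here's a minimal sketch of the setup I'm asking about (the vocabulary size, dimension, and index are just placeholder values). It does look like the linear layer's gradient is materialized at full shape, with zeros everywhere except the selected column:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab, dim = 10, 4

# Linear layer fed a one-hot vector (bias disabled so it matches a lookup)
linear = nn.Linear(vocab, dim, bias=False)
one_hot = torch.zeros(1, vocab)
one_hot[0, 3] = 1.0  # stands in for "index 3"

out = linear(one_hot)
out.sum().backward()

# The gradient tensor is dense: full (dim, vocab) shape...
print(linear.weight.grad.shape)  # torch.Size([4, 10])

# ...but only the column for index 3 is nonzero
nonzero_cols = linear.weight.grad.abs().sum(dim=0).nonzero()
print(nonzero_cols)  # tensor([[3]])
```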

Will PyTorch do the same for nn.Embedding, or will it backpropagate only to the selected embedding row?
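For comparison, here's the same lookup with nn.Embedding. By default the gradient is still a full-shape tensor, but with the sparse=True flag from the docs the gradient comes back as a sparse tensor that only materializes the touched row (same placeholder sizes as above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab, dim = 10, 4

# Default embedding: grad has the full (vocab, dim) shape,
# zeros everywhere except the looked-up row
emb = nn.Embedding(vocab, dim)
emb(torch.tensor([3])).sum().backward()
print(emb.weight.grad.shape)  # torch.Size([10, 4])
print(emb.weight.grad.abs().sum(dim=1).nonzero())  # tensor([[3]])

# With sparse=True the gradient is a torch.sparse tensor
emb_sp = nn.Embedding(vocab, dim, sparse=True)
emb_sp(torch.tensor([3])).sum().backward()
print(emb_sp.weight.grad.is_sparse)  # True
```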

Assuming limited GPU memory and large CPU memory: do both approaches require the same minimal (greater than zero) amount of data to be sent to the GPU?