As far as I know, DDP (DistributedDataParallel) only works when all parameters of the given module participate in computing the loss.
In the case of nn.Embedding, some parameters of the module may not be used in the forward pass, since only the rows corresponding to the indices in the batch are looked up.
However, Transformers such as BERT work well with DDP.
I don't understand how nn.Embedding can work with DDP.
Even if not all indices of the embedding weight matrix are used, the parameter itself would still be used and would thus get a valid gradient (zeros for the unused indices), so DDP shouldn’t complain about it.
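A minimal single-process sketch (no actual DDP involved, just autograd) illustrating the point above: even when only some embedding rows are looked up, `emb.weight.grad` is a full dense tensor of the weight's shape, with zeros in the rows for unused indices, so DDP always has a complete gradient to all-reduce.

```python
import torch
import torch.nn as nn

# Sketch: nn.Embedding's weight receives one gradient tensor covering
# every row; rows whose indices never appeared in the batch are zero.
emb = nn.Embedding(num_embeddings=5, embedding_dim=3)

# Forward pass uses only indices 0 and 2; rows 1, 3, 4 are never looked up.
idx = torch.tensor([0, 2])
loss = emb(idx).sum()
loss.backward()

grad = emb.weight.grad
print(grad.shape)                      # full shape: torch.Size([5, 3])
print(torch.all(grad[1] == 0).item())  # unused row -> zero gradient
print(torch.all(grad[0] == 1).item())  # used row -> d(sum)/dw = 1 per element
```

Note this assumes the default `sparse=False`; with `sparse=True` the gradient is a sparse tensor instead, which DDP does not support out of the box.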