I want to implement a module containing a lookup table that is updated manually with the module's outputs. The idea comes from this paper: https://arxiv.org/pdf/2008.01466.pdf
If I use DDP, the copies of the table (an nn.Embedding with requires_grad=False) on different GPUs can drift apart, because each GPU writes only its own outputs into its copy. I want the table to stay in sync across all GPUs. What's the best way to implement this?
One option I can think of is keeping this table in CPU memory and exchanging data between the CPU and the GPUs, but I'm not sure how much this would slow down training.
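A minimal sketch of a GPU-side alternative, assuming `torch.distributed` is already initialized (e.g. by the DDP launcher), the indices are a 1-D long tensor, and every rank calls `update` at the same step with index/value tensors of the same shape (equal per-GPU batch sizes). `SyncedLookupTable` and `update` are hypothetical names. Instead of an `nn.Embedding` with `requires_grad=False`, the table is a registered buffer; each update all-gathers the (index, value) pairs from every rank so all replicas apply identical writes, and duplicate indices are resolved deterministically (the last write wins).

```python
import torch
import torch.distributed as dist
import torch.nn as nn


class SyncedLookupTable(nn.Module):
    """Hypothetical non-trainable table kept identical on every rank.

    Assumes every rank calls update() at the same step with tensors of
    the same shape, so all_gather can be used without padding.
    """

    def __init__(self, num_entries: int, dim: int):
        super().__init__()
        # A buffer, not a Parameter: the optimizer never updates it and
        # DDP's gradient all-reduce never touches it.
        self.register_buffer("table", torch.zeros(num_entries, dim))

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # Plain row lookup, like nn.Embedding.
        return self.table[idx]

    @torch.no_grad()
    def update(self, idx: torch.Tensor, values: torch.Tensor) -> None:
        values = values.detach()
        if dist.is_available() and dist.is_initialized():
            world = dist.get_world_size()
            # Collect every rank's (index, value) pairs so that all
            # ranks apply exactly the same set of writes.
            all_idx = [torch.empty_like(idx) for _ in range(world)]
            all_val = [torch.empty_like(values) for _ in range(world)]
            dist.all_gather(all_idx, idx)
            dist.all_gather(all_val, values)
            idx = torch.cat(all_idx)
            values = torch.cat(all_val)
        # Plain indexed assignment with duplicate indices is undefined
        # behavior, so pick the last occurrence of each index explicitly.
        uniq, inverse = torch.unique(idx, return_inverse=True)
        pos = torch.arange(idx.numel(), device=idx.device)
        last = torch.full_like(uniq, -1)
        last.scatter_reduce_(0, inverse, pos, reduce="amax")
        self.table[uniq] = values[last].to(self.table.dtype)
```

In a training step this might look like `model.module.table_module.update(batch_idx, outputs)` (hypothetical attribute names). Only the updated rows travel over the network, so each rank receives roughly world_size × batch × dim values per step rather than the whole table, and the table never leaves the GPU. DDP's default `broadcast_buffers=True` re-broadcasts buffers from rank 0 during forward; with this scheme all ranks already agree, so `broadcast_buffers=False` would only save redundant traffic (provided no other buffers, such as BatchNorm statistics, rely on that broadcast).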