How to sync a manually updated lookup table in DDP

Hi,

I want to implement a Module that maintains a lookup table which is manually updated with the module's outputs. The idea comes from this paper: https://arxiv.org/pdf/2008.01466.pdf

If I use DDP, the per-GPU copies of the table (an nn.Embedding with requires_grad=False) can drift apart, since each process applies only its own updates. But I want the table to stay in sync across GPUs. What's the best way to implement this?
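
For concreteness, here is roughly the pattern I mean (a minimal sketch; the module and names are just placeholders, not the paper's actual architecture):

```python
import torch
import torch.nn as nn

class LookupModule(nn.Module):
    def __init__(self, num_entries: int, dim: int):
        super().__init__()
        self.encoder = nn.Linear(dim, dim)
        # The table is written to manually, never touched by the optimizer.
        self.table = nn.Embedding(num_entries, dim)
        self.table.weight.requires_grad = False

    def forward(self, x: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
        out = self.encoder(x)
        # Manual update: each DDP process writes only its own outputs into
        # its local copy, so the copies diverge across GPUs.
        with torch.no_grad():
            self.table.weight[ids] = out.detach()
        return out
```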

One idea is to keep the table in CPU memory and exchange data between the CPU and the GPUs, but I'm not sure how much that would slow down training.
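
Something like the sketch below is what I had in mind. One thing I realize while writing it: since DDP runs one process per GPU, a plain CPU tensor is still per-process, so share_memory_() would at least make a single CPU copy visible to all workers on one node (it wouldn't help across nodes, and concurrent writes are unsynchronized):

```python
import torch

num_entries, dim = 1000, 128
device = torch.device("cuda")

# One CPU copy of the table, placed in shared memory so that every DDP
# worker process on this node sees the same storage.
table_cpu = torch.zeros(num_entries, dim)
table_cpu.share_memory_()

def lookup(ids: torch.Tensor) -> torch.Tensor:
    # CPU -> GPU: fetch only the rows needed for this step.
    return table_cpu[ids].to(device, non_blocking=True)

@torch.no_grad()
def write_back(ids: torch.Tensor, values: torch.Tensor) -> None:
    # GPU -> CPU: write this rank's fresh outputs back into the shared copy.
    table_cpu[ids] = values.detach().cpu()
```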

Thanks

@pritamdamania87 to follow up here.

Ouch, sounds nasty… If the tables are PyTorch tensors, then I think you should be able to do that with torch.distributed.reduce, gather, etc. (detectron2/MODEL_ZOO.md at main · facebookresearch/detectron2 · GitHub)
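
Something along these lines might work as a starting point (untested sketch; assumes every rank updates the same number of rows per step, and that `ids`/`values` are CUDA tensors when using the NCCL backend):

```python
import torch
import torch.distributed as dist

@torch.no_grad()
def sync_table_update(table: torch.nn.Embedding,
                      ids: torch.Tensor,
                      values: torch.Tensor) -> None:
    """Gather this step's (ids, values) from all ranks and apply every
    rank's rows locally, so all copies of the table stay identical."""
    world_size = dist.get_world_size()

    all_ids = [torch.empty_like(ids) for _ in range(world_size)]
    all_values = [torch.empty_like(values) for _ in range(world_size)]
    dist.all_gather(all_ids, ids)
    dist.all_gather(all_values, values)

    # Apply every rank's updates to this rank's local table.
    for i, v in zip(all_ids, all_values):
        table.weight[i] = v
```

You'd call it right after the manual update (or once per step from the training loop), e.g. `sync_table_update(model.module.table, ids, out.detach())` when the model is wrapped in DDP.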