How to sync a manually updated lookup table in DDP


I want to implement a Module with a lookup table that is manually updated with module outputs. The idea comes from this paper:

If I use DDP, the per-GPU copies of the table (an nn.Embedding with requires_grad=False) can drift apart, but I want the table kept in sync across GPUs. What's the best way to implement this?

One option I can think of is keeping the table in CPU memory and exchanging data between the CPU and GPUs, but I'm not sure how much that would slow down training.


@pritamdamania87 to follow up here.

Ouch, sounds nasty… If the tables are PyTorch tensors, then I think you should be able to do that with torch.distributed.reduce, gather, etc. (detectron2/ at main · facebookresearch/detectron2 · GitHub)
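To make that concrete, here is a minimal sketch of the collective-op approach. It assumes each rank updates its local copy of the table and then calls an all_reduce to average the copies (the `sync_table` helper name and the averaging policy are my assumptions, not something from the paper). Since the table has `requires_grad=False`, DDP's gradient hooks will never touch it, so you have to sync it yourself:

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn

def sync_table(table: nn.Embedding) -> None:
    """Average a manually-updated lookup table across all ranks.

    Assumes the process group is already initialized. Because the table
    is updated outside autograd (requires_grad=False), DDP will not sync
    it, so we all_reduce the weights ourselves.
    """
    with torch.no_grad():
        # Sum the per-rank copies, then divide to get the element-wise mean.
        dist.all_reduce(table.weight, op=dist.ReduceOp.SUM)
        table.weight.div_(dist.get_world_size())

# Single-process demo using the CPU-only "gloo" backend; in real DDP
# training the launcher (torchrun) sets these env vars and the world size.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

table = nn.Embedding(10, 4)
table.weight.requires_grad_(False)

before = table.weight.clone()
sync_table(table)  # with world_size == 1 the average is a no-op
assert torch.allclose(before, table.weight)

dist.destroy_process_group()
```

You would call `sync_table` after each manual update (or every N steps, to amortize the communication cost); since the table stays on GPU and NCCL handles the transfer, this avoids the CPU round-trip you were worried about.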