Calling DistributedDataParallel on multiple Modules?

Using the same process group for multiple DDP-wrapped modules may work, but only if the modules are used independently and a single call to backward never involves both models. For GANs this may be the case, since you alternate between training the discriminator and the generator. However, if a single call to backward accumulates gradients for more than one DDP-wrapped module, then you'll have to use a different process group for each of them to avoid interference.

You can do this as follows:

import torch.distributed
from torch.nn.parallel import DistributedDataParallel

# Create one process group per DDP-wrapped module so their
# gradient all-reduce operations don't interfere with each other.
pg1 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
model1 = DistributedDataParallel(
    create_model1(),
    device_ids=[args.local_rank],
    process_group=pg1)

pg2 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
model2 = DistributedDataParallel(
    create_model2(),
    device_ids=[args.local_rank],
    process_group=pg2)
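
For context, here is a minimal sketch of the situation where separate groups matter: a single loss whose backward pass accumulates gradients through both DDP-wrapped modules. Names like batch, target, loss_fn, and the optimizers are placeholders, not part of the original example.

# Hypothetical training step: one backward pass touches both wrapped
# modules, so each needs its own process group (as set up above).
out1 = model1(batch)            # batch is a placeholder input tensor
out2 = model2(out1)             # output of model1 feeds into model2
loss = loss_fn(out2, target)    # loss_fn and target are placeholders
loss.backward()                 # gradients flow through model2 AND model1,
                                # triggering all-reduce in both DDP instances
opt1.step()                     # separate optimizers, also placeholders
opt2.step()
opt1.zero_grad()
opt2.zero_grad()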