Calling DistributedDataParallel on multiple Modules?

I’m wondering if anyone has insight into the effects of calling DDP twice after a single process group initialization. A good example would be a GAN, where there are two distinct models. Can both safely be wrapped in DDP? I suppose a dummy container module could be made that encases both models and only requires a single DDP wrapper. I looked around but didn’t see anything posted about this.
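A minimal sketch of that container idea, under stated assumptions: the `GANContainer` class, its `mode` argument, and the layer sizes are hypothetical. Each step goes through the DDP wrapper's forward so its hooks run. Because the discriminator step uses only D's parameters, `find_unused_parameters=True` is set. Real and fake samples are concatenated into one forward, because DDP expects one forward per backward. The single-process gloo group on CPU exists only so the snippet runs standalone.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


class GANContainer(nn.Module):
    """Hypothetical container holding both models under one DDP wrapper."""

    def __init__(self, latent_dim=8, data_dim=4):
        super().__init__()
        self.G = nn.Sequential(nn.Linear(latent_dim, 16), nn.ReLU(), nn.Linear(16, data_dim))
        self.D = nn.Sequential(nn.Linear(data_dim, 16), nn.ReLU(), nn.Linear(16, 1))

    def forward(self, x, mode):
        # Route every step through forward() so DDP's hooks see it.
        if mode == "g_step":
            return self.D(self.G(x))  # x is latent noise; touches G and D
        if mode == "d_step":
            return self.D(x)  # x is real or detached fake data; touches D only
        raise ValueError(f"unknown mode: {mode}")


def main():
    # Single-process gloo group on CPU, only so the sketch runs standalone.
    store_path = os.path.join(tempfile.mkdtemp(), "store")
    dist.init_process_group("gloo", init_method=f"file://{store_path}",
                            rank=0, world_size=1)

    model = DistributedDataParallel(GANContainer(), find_unused_parameters=True)
    inner = model.module
    opt_g = torch.optim.Adam(inner.G.parameters(), lr=1e-3)
    opt_d = torch.optim.Adam(inner.D.parameters(), lr=1e-3)
    bce = nn.BCEWithLogitsLoss()
    real, z = torch.randn(32, 4), torch.randn(32, 8)

    # Discriminator step: fakes made without autograd, then one forward
    # over real + fake samples through the DDP wrapper.
    with torch.no_grad():
        fake = inner.G(z)
    logits = model(torch.cat([real, fake]), mode="d_step")
    labels = torch.cat([torch.ones(32, 1), torch.zeros(32, 1)])
    opt_d.zero_grad()
    bce(logits, labels).backward()
    opt_d.step()

    # Generator step: gradients flow through D into G; only G is updated.
    opt_g.zero_grad()
    g_loss = bce(model(z, mode="g_step"), torch.ones(32, 1))
    g_loss.backward()
    opt_g.step()

    dist.destroy_process_group()
    return g_loss.item()
```

Calling `main()` runs one discriminator step and one generator step and returns the generator loss.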

Hi, I’m also curious about this question. Have you solved it?


Using the same process group for multiple DDP-wrapped modules may work, but only if they are used independently and a single call to backward() doesn’t involve both models. This may be the case for GANs, where you alternate between training the discriminator and the generator. That said, if a single call to backward() accumulates gradients for more than one DDP-wrapped module, you’ll have to use a different process group for each of them to avoid interference.

You can do this as follows:

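```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# create_model1/create_model2 and args.local_rank stand in for your own code.
ranks = list(range(dist.get_world_size()))

pg1 = dist.new_group(ranks)
model1 = DistributedDataParallel(
    create_model1().to(args.local_rank),  # module must live on its device
    device_ids=[args.local_rank],
    process_group=pg1)

pg2 = dist.new_group(ranks)
model2 = DistributedDataParallel(
    create_model2().to(args.local_rank),
    device_ids=[args.local_rank],
    process_group=pg2)
```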
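A standalone sketch of that setup, adapted to run on CPU in a single process. It uses a gloo group backed by a temporary `file://` store instead of GPUs and `args.local_rank`. Toy `nn.Linear` layers stand in for the hypothetical `create_model1()`/`create_model2()`. With `world_size=1` no real communication happens, so this only shows the structure: one group per wrapper, a D step on detached fakes, then a G step whose backward flows through both models. It cannot reproduce the interference itself.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def main():
    # Single-process gloo group on CPU, only so the sketch runs standalone.
    # A real job would launch one process per GPU with its own rank.
    store_path = os.path.join(tempfile.mkdtemp(), "store")
    dist.init_process_group("gloo", init_method=f"file://{store_path}",
                            rank=0, world_size=1)
    ranks = list(range(dist.get_world_size()))

    # One process group per DDP wrapper; every rank must create them in the same order.
    pg_g = dist.new_group(ranks)
    pg_d = dist.new_group(ranks)
    G = DistributedDataParallel(nn.Linear(8, 4), process_group=pg_g)  # toy generator
    D = DistributedDataParallel(nn.Linear(4, 1), process_group=pg_d)  # toy discriminator
    opt_g = torch.optim.SGD(G.parameters(), lr=0.1)
    opt_d = torch.optim.SGD(D.parameters(), lr=0.1)
    bce = nn.BCEWithLogitsLoss()

    real, z = torch.randn(16, 4), torch.randn(16, 8)
    fake = G(z)

    # Discriminator step: detaching the fakes keeps G out of this backward.
    d_logits = D(torch.cat([real, fake.detach()]))
    d_labels = torch.cat([torch.ones(16, 1), torch.zeros(16, 1)])
    opt_d.zero_grad()
    bce(d_logits, d_labels).backward()
    opt_d.step()

    # Generator step: this backward runs through D and then G, so both
    # wrappers reduce gradients in the same call (each on its own group).
    opt_g.zero_grad()
    g_loss = bce(D(fake), torch.ones(16, 1))
    g_loss.backward()
    opt_g.step()  # only G's weights change here

    dist.destroy_process_group()
    return g_loss.item()
```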

This is interesting. With most GAN architectures, the generator’s backward pass does indeed involve both G and D. Generally you get a prediction from the discriminator on the fake samples and call backward on that, propagating through D, then G. It may then actually be a requirement that GANs use separate process groups for each model. Granted, in this architecture weights are updated for only one model at a time, but gradients are accumulated for both D and G during the G backward/update stage.
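A quick way to check this claim without DDP, assuming toy `nn.Linear` models (sizes are arbitrary): after calling backward on the generator loss, D’s parameters hold gradients too, because the backward pass ran through D before reaching G.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
G = nn.Linear(8, 4)  # toy generator
D = nn.Linear(4, 1)  # toy discriminator

z = torch.randn(16, 8)
fake = G(z)
# Generator loss: D's prediction on fake samples, labeled as "real".
g_loss = F.binary_cross_entropy_with_logits(D(fake), torch.ones(16, 1))
g_loss.backward()

# Backward ran through D first, then G, so both models now hold gradients.
g_has_grad = all(p.grad is not None for p in G.parameters())
d_has_grad = all(p.grad is not None for p in D.parameters())
print(g_has_grad, d_has_grad)  # True True
```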

If it is indeed the case that GANs need separate process groups for G and D, then that is something that definitely needs to be in the docs. I’ve had some strange training results while using DDP and this may be the cause.

@pietern Do you know if the interference you speak of would cause an exception, or just produce incorrect gradients?

I want to put together a testbed for this and see whether the gradients actually differ when using 1 vs. 2 PGs.

@mdlockyer I think bad behavior would show up as crashes or hangs. Mixing and matching allreduce calls with different dimensionality across processes is a recipe for out-of-bounds memory access and undefined behavior.

Would it be possible to use autograd.grad for the discriminator part and autograd.backward for the generator part? This would avoid gradient accumulation and reduction for the discriminator. You’d still have to stitch them together at the boundary of course.

@pietern that’s an interesting idea! Only backward() will trigger the reducer, right? Off the top of your head, what would the stitch between grad and backward look like?

Yes, only backward() interacts with the reducer.

I imagine combining grad and backward like this (but YMMV and I don’t know if this works).

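```python
# autograd.grad returns a tuple; unpack the single gradient w.r.t. G_out
(G_out_grad,) = torch.autograd.grad(D_loss, G_out)
# Resume backprop from G_out into the generator with that upstream gradient
torch.autograd.backward(G_out, G_out_grad)
```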

This would trigger reduction only for the generator, not for the discriminator.
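A plain-autograd check of that stitch, without DDP. The toy `nn.Linear` models and the reference comparison are assumptions for illustration. It shows that D’s `.grad` fields stay empty and that G ends up with the same gradients a full backward would give. It does not verify how DDP’s reducer for D behaves when its backward is skipped.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
G = nn.Linear(8, 4)  # toy generator
D = nn.Linear(4, 1)  # toy discriminator
z = torch.randn(16, 8)
target = torch.ones(16, 1)

G_out = G(z)
D_loss = F.binary_cross_entropy_with_logits(D(G_out), target)

# Step 1: autograd.grad stops at G_out and returns the gradient instead of
# accumulating into D's .grad fields.
(G_out_grad,) = torch.autograd.grad(D_loss, G_out)
# Step 2: continue backprop from G_out into G only.
torch.autograd.backward(G_out, G_out_grad)

stitched = [p.grad.clone() for p in G.parameters()]
d_untouched = all(p.grad is None for p in D.parameters())

# Reference: a fresh forward plus a plain backward through D and G.
G.zero_grad()
F.binary_cross_entropy_with_logits(D(G(z)), target).backward()
reference = [p.grad for p in G.parameters()]
matches = all(torch.allclose(a, b) for a, b in zip(stitched, reference))
print(d_untouched, matches)  # True True
```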

Very cool. When I get a second I’m going to implement this and test it out.
