torch.cuda.amp, DistributedDataParallel and GAN training

I’m trying to train a GAN using torch.cuda.amp and DistributedDataParallel. Training works when mixed precision is disabled, or, after a slight refactoring, with apex.amp and mixed precision enabled.

What I had to do with apex.amp to make it work was:

```python
import apex.amp
import torch
import torch.distributed

# generator and discriminator instantiation
g_net = ...
d_net = ...

g_optim = torch.optim.Adam(g_net.parameters(), lr=0.0001, betas=(0, 0.9))
d_optim = torch.optim.Adam(d_net.parameters(), lr=0.0004, betas=(0, 0.9))

# apex.amp.initialize is called BEFORE the DDP wrap
[d_net, g_net], [d_optim, g_optim] = apex.amp.initialize([d_net, g_net],
                                                         [d_optim, g_optim],
                                                         opt_level="O1",
                                                         num_losses=2)

# args.local_rank comes from the launcher's argument parsing (not shown)
g_pg = torch.distributed.new_group(range(torch.distributed.get_world_size()))
g_net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(g_net, process_group=g_pg)
g_net = torch.nn.parallel.DistributedDataParallel(
    g_net.cuda(args.local_rank),
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    process_group=g_pg,
)

# similarly for discriminator
d_net = ...
```

When I change the order of DistributedDataParallel and apex.amp.initialize, i.e.

```python
g_net = torch.nn.parallel.DistributedDataParallel(g_net, ...)
d_net = torch.nn.parallel.DistributedDataParallel(d_net, ...)

g_optim = torch.optim.Adam(g_net.parameters(), lr=0.0001, betas=(0, 0.9))
d_optim = torch.optim.Adam(d_net.parameters(), lr=0.0004, betas=(0, 0.9))

# !!! apex.amp.initialize needs the wrapped module !!!
[d_net, g_net], [d_optim, g_optim] = apex.amp.initialize([d_net.module, g_net.module],
                                                         [d_optim, g_optim],
                                                         opt_level="O1",
                                                         num_losses=2)
```

the code runs, but the final generator outputs rubbish. So I suspect the issue is the order in which a module is cast to mixed precision and distributed across multiple processes, which we have no control over with the new torch.cuda.amp. Unfortunately, the documentation at https://pytorch.org/docs/master/notes/amp_examples.html#distributeddataparallel-one-gpu-per-process lacks specifics or an example.

Specific to GAN training, I have two more questions:

  1. Is there a way to check whether optimizer steps were skipped due to NaNs? For GANs it could potentially be critical that the generator and discriminator update in lock-step, without one skipping steps while the other doesn’t.
  2. I checked the code of torch.cuda.amp.GradScaler, and it seems to use the same scale for all inputs. For GANs, the discriminator and generator may need different scales. In that case, can I just create two scalers?

Since you don’t have to initialize the model for mixed-precision training in native AMP, you should wrap the model in DDP and apply autocast and gradient scaling as described in the examples.
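A minimal sketch of that setup, assuming PyTorch 1.6 or newer (on recent releases the `torch.cuda.amp` names still work but may emit deprecation warnings in favour of `torch.amp`). The toy `nn.Sequential` networks, the hinge loss and the helper names `gan_step`/`demo` are hypothetical stand-ins for the elided `g_net = ...`/`d_net = ...`; only the Adam settings come from the question. The models are wrapped in DDP directly, with no initialization step, because `autocast` picks the precision per operation at runtime and leaves the parameters in float32. Each network gets its own `GradScaler`. `demo()` runs a single process with `world_size=1` (gloo and plain fp32 when CUDA is unavailable); a real job would launch one process per GPU and pass `device_ids=[args.local_rank]`.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP


def gan_step(g_net, d_net, g_optim, d_optim, g_scaler, d_scaler, real, z_dim, use_amp):
    """One discriminator update and one generator update under native AMP (hinge loss)."""
    z = torch.randn(real.size(0), z_dim, device=real.device)

    # Discriminator: real and detached fake samples share one d_net forward call.
    d_optim.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast(enabled=use_amp):
        fake = g_net(z)
        d_real, d_fake = d_net(torch.cat([real, fake.detach()])).split(real.size(0))
        d_loss = F.relu(1.0 - d_real).mean() + F.relu(1.0 + d_fake).mean()
    d_scaler.scale(d_loss).backward()
    d_scaler.step(d_optim)
    d_scaler.update()

    # Generator: gradients flow back through d_net into g_net; only g_optim steps.
    # (d_net also receives grads here; d_optim.zero_grad clears them next iteration.)
    g_optim.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast(enabled=use_amp):
        g_loss = -d_net(fake).mean()
    g_scaler.scale(g_loss).backward()
    g_scaler.step(g_optim)
    g_scaler.update()
    return d_loss.item(), g_loss.item()


def demo(steps=3, batch=32, z_dim=8, data_dim=4):
    """Hypothetical single-process run; a real job launches one process per GPU."""
    use_amp = torch.cuda.is_available()  # native AMP here is CUDA-only; fp32 on CPU
    device = torch.device("cuda", 0) if use_amp else torch.device("cpu")
    if not dist.is_initialized():
        init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")
        dist.init_process_group("nccl" if use_amp else "gloo",
                                init_method=f"file://{init_file}", rank=0, world_size=1)
    device_ids = [0] if use_amp else None

    # No amp.initialize-style step: build, move to the device, wrap in DDP.
    g_net = DDP(nn.Sequential(nn.Linear(z_dim, 16), nn.ReLU(), nn.Linear(16, data_dim)).to(device),
                device_ids=device_ids)
    d_net = DDP(nn.Sequential(nn.Linear(data_dim, 16), nn.ReLU(), nn.Linear(16, 1)).to(device),
                device_ids=device_ids)
    g_optim = torch.optim.Adam(g_net.parameters(), lr=0.0001, betas=(0, 0.9))
    d_optim = torch.optim.Adam(d_net.parameters(), lr=0.0004, betas=(0, 0.9))

    # One GradScaler per network, so each keeps its own loss scale.
    g_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    d_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    losses = [gan_step(g_net, d_net, g_optim, d_optim, g_scaler, d_scaler,
                       torch.randn(batch, data_dim, device=device), z_dim, use_amp)
              for _ in range(steps)]
    # autocast never casts the parameters, so their dtypes stay float32.
    return losses, {p.dtype for p in g_net.parameters()}
```

Feeding real and fake samples through the discriminator in one `torch.cat` call keeps a single DDP forward per backward; with BatchNorm layers this changes the batch statistics compared with two separate calls, so it is a design choice rather than a requirement. The SyncBatchNorm conversion from the question is left out because it needs CUDA; it would go before the DDP wrap, as in the original code.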

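  1. I don’t know if there is a clean way at the moment (and I think we should add a utility function for it). For now, you could check whether the scale factor decreased after calling scaler.step(optimizer) and scaler.update(): update() lowers the scale only when inf/NaN gradients were found, which is exactly when the optimizer step was skipped.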

  2. Yes, that should be possible, but please let us know if you encounter any issues.
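A sketch of the scale-comparison check, assuming the `torch.cuda.amp.GradScaler` API. `GradScaler.update()` multiplies the scale by `backoff_factor` (0.5 by default) only when inf/NaN gradients were found, which is exactly when `scaler.step()` skipped `optimizer.step()`, so the comparison has to be made after `update()`. The helper name `backward_step_update` is hypothetical, and `get_scale()` forces a host-device sync, so calling it every iteration has a small cost.

```python
import torch


def backward_step_update(scaler: torch.cuda.amp.GradScaler,
                         optimizer: torch.optim.Optimizer,
                         loss: torch.Tensor) -> bool:
    """Scaled backward + step + update; returns True if optimizer.step() was skipped.

    Relies on update() shrinking the scale (default backoff_factor < 1) exactly
    when inf/NaN gradients were found, i.e. exactly when step() skipped the update.
    """
    scaler.scale(loss).backward()
    scale_before = scaler.get_scale()  # .item() on the scale tensor: syncs with the GPU
    scaler.step(optimizer)
    scaler.update()
    # A disabled scaler never skips and always reports a scale of 1.0.
    return scaler.is_enabled() and scaler.get_scale() < scale_before
```

With one scaler per network, calling it once for the discriminator and once for the generator yields two flags per iteration, which can be compared or logged to see how often the two updates fall out of lock-step.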
