I’m trying to train a GAN using torch.cuda.amp and DistributedDataParallel. Training works when mixed precision is disabled, or, with a slight refactoring, when using apex.amp with mixed precision enabled.
What I had to do with apex.amp to make it work was:
# generator and discriminator instantiation
g_net = ...
d_net = ...
g_optim = torch.optim.Adam(g_net.parameters(), lr=0.0001, betas=(0, 0.9))
d_optim = torch.optim.Adam(d_net.parameters(), lr=0.0004, betas=(0, 0.9))
[d_net, g_net], [d_optim, g_optim] = apex.amp.initialize([d_net, g_net],
                                                          [d_optim, g_optim],
                                                          opt_level="O1",
                                                          num_losses=2)
g_pg = torch.distributed.new_group(range(torch.distributed.get_world_size()))
g_net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(g_net, process_group=g_pg)
g_net = torch.nn.parallel.DistributedDataParallel(
    g_net.cuda(args.local_rank),
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    process_group=g_pg,
)
# similarly for discriminator
d_net = ...
When I change the order of DistributedDataParallel and apex.amp.initialize, i.e.
g_net = torch.nn.parallel.DistributedDataParallel(g_net, ....)
d_net = torch.nn.parallel.DistributedDataParallel(d_net, ....)
g_optim = torch.optim.Adam(g_net.parameters(), lr=0.0001, betas=(0, 0.9))
d_optim = torch.optim.Adam(d_net.parameters(), lr=0.0004, betas=(0, 0.9))
# !!! apex.amp.initialize needs the underlying module, not the DDP wrapper !!!
[d_net, g_net], [d_optim, g_optim] = apex.amp.initialize([d_net.module, g_net.module],
                                                          [d_optim, g_optim],
                                                          opt_level="O1",
                                                          num_losses=2)
the code runs, but the final generator outputs rubbish. So I think the issue is the order in which a module is cast to mixed precision and distributed across multiple processes, which we have no control over with the new torch.cuda.amp. The documentation at https://pytorch.org/docs/master/notes/amp_examples.html#distributeddataparallel-one-gpu-per-process unfortunately lacks specifics and a full example.
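For completeness, this is roughly the torch.cuda.amp training step I am attempting (a minimal sketch only; latent_dim, the data loader, and the two loss functions are placeholders, and a single GradScaler is shared by both optimizers, with one update() per iteration):

# g_net and d_net are the DDP-wrapped models from above
scaler = torch.cuda.amp.GradScaler()

for real in loader:
    real = real.cuda(args.local_rank, non_blocking=True)
    noise = torch.randn(real.size(0), latent_dim, device=real.device)

    # discriminator update
    d_optim.zero_grad()
    with torch.cuda.amp.autocast():
        fake = g_net(noise)
        d_loss = d_loss_fn(d_net(real), d_net(fake.detach()))  # placeholder loss
    scaler.scale(d_loss).backward()
    scaler.step(d_optim)

    # generator update
    g_optim.zero_grad()
    with torch.cuda.amp.autocast():
        g_loss = g_loss_fn(d_net(g_net(noise)))  # placeholder loss
    scaler.scale(g_loss).backward()
    scaler.step(g_optim)

    # a single update() after both optimizers have stepped
    scaler.update()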
Specifically for GAN training, I have two more questions:
- Is there a way to check whether an optimizer step was skipped due to NaNs/infs? For GANs it could potentially be critical that generator and discriminator update in lock-step, without one skipping steps while the other does not. (A possible workaround I would try is sketched after this list.)
- I checked the code of torch.cuda.amp.GradScaler and it seems to use the same scale for all inputs. For GANs the discriminator and generator may well need different scales. In this case, can I just create two scalers? (See the second sketch after this list.)
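For the first question, the only workaround I can think of is to watch the scale before and after scaler.update(): if the scale decreased, an inf/NaN was found and a step was skipped. With two optimizers sharing one scaler this only tells me that some step was skipped, not which one, which is exactly why I am asking whether there is a proper way:

# minimal sketch: infer a skipped step from the scale shrinking after update()
scale_before = scaler.get_scale()
scaler.step(d_optim)
scaler.step(g_optim)
scaler.update()
step_was_skipped = scaler.get_scale() < scale_before
if step_was_skipped:
    print("at least one optimizer step was skipped due to infs/NaNs")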
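For the second question, my naive idea would be to simply instantiate two independent GradScaler objects so the scales can diverge, e.g.:

g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

# discriminator update uses its own scaler ...
d_scaler.scale(d_loss).backward()
d_scaler.step(d_optim)
d_scaler.update()

# ... and the generator update uses another one
g_scaler.scale(g_loss).backward()
g_scaler.step(g_optim)
g_scaler.update()

Is that the intended usage, or does it break some assumption inside GradScaler?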