Using apex AMP (Automatic Mixed Precision) with model parallelism

My model has a few LSTMs which run out of CUDA memory when run on long sequences on a single GPU. So I moved a few components of the model to another GPU. I tried two things with Apex AMP:

  1. Move the model components to the other GPU before invoking amp.initialize. In this case, I get NaNs soon after the first backpropagation (see the sketch after this list).
  2. First invoke amp.initialize, and then move the model components to the other GPU. In this case, it's as if the entire backward pass runs on a single GPU, and it runs out of CUDA memory.
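
For reference, here is a minimal sketch of attempt 1. The toy two-LSTM module, layer sizes, and optimizer below are made up purely for illustration; only the ordering of the .to(...) calls relative to amp.initialize reflects what I'm actually doing.

```python
import torch
import torch.nn as nn
from apex import amp

class TwoDeviceLSTM(nn.Module):
    # Illustrative model-parallel split: one LSTM per GPU.
    def __init__(self):
        super().__init__()
        self.lstm0 = nn.LSTM(64, 64, batch_first=True).to("cuda:0")
        self.lstm1 = nn.LSTM(64, 64, batch_first=True).to("cuda:1")

    def forward(self, x):
        out, _ = self.lstm0(x.to("cuda:0"))
        out, _ = self.lstm1(out.to("cuda:1"))
        return out

# Attempt 1: components already live on their devices *before* amp.initialize.
model = TwoDeviceLSTM()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(8, 16, 64)
loss = model(x).float().mean()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()   # NaNs show up shortly after this
optimizer.step()

# Attempt 2 is the same, except amp.initialize is called first and
# lstm1 is only moved to cuda:1 afterwards.
```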

The model trains fine without Apex, so I suppose I am missing some step needed for the loss to be backpropagated across both GPUs. I looked through the Apex documentation, but it only talks about data parallelism, not model parallelism.

Any ideas?

Hey @Caesar, have you tried the native AMP in PyTorch? I haven’t tried that with model parallelism yet, but if it does not work, that would be a bug we need to fix.

https://pytorch.org/docs/stable/amp.html
https://pytorch.org/docs/stable/notes/amp_examples.html#amp-examples

Thanks for your response. I am constrained to use an older version of PyTorch which does not support AMP natively, so I am using NVIDIA Apex: https://github.com/NVIDIA/apex

I see.

cc AMP author @mcarilli

I don’t think the apex amp API supports this without complex/undocumented hacks. Apex amp is in maintenance mode now; no new features will be added.

However, torch.cuda.amp is designed to support model parallelism (i.e., different layers on different devices) out of the box. Please consider upgrading if at all possible.
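
As a rough sketch of the pattern (the toy layer split, sizes, and optimizer below are only illustrative), autocast plus GradScaler should work unchanged with layers spread across two GPUs:

```python
import torch
import torch.nn as nn

class TwoDeviceNet(nn.Module):
    # Toy model-parallel network: one layer per GPU.
    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(64, 64).to("cuda:0")
        self.fc1 = nn.Linear(64, 64).to("cuda:1")

    def forward(self, x):
        x = torch.relu(self.fc0(x.to("cuda:0")))
        return self.fc1(x.to("cuda:1"))

model = TwoDeviceNet()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    optimizer.zero_grad()
    x = torch.randn(32, 64)
    with torch.cuda.amp.autocast():   # covers ops on both GPUs used by this thread
        loss = model(x).float().mean()
    scaler.scale(loss).backward()     # grads land on each layer's own device
    scaler.step(optimizer)
    scaler.update()
```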


Thanks for your response.

There’s a multi-GPU torch.cuda.amp.GradScaler test case that ensures ordinary GradScaler usage supports networks with layers on different devices.

torch.cuda.amp.autocast locally enables/disables autocast for all devices used by the invoking thread. (However, the autocast state is thread local, so if you spawn a thread to control each device, you must re-invoke autocast in the side thread(s). This affects usage with torch.nn.DataParallel and torch.nn.parallel.DistributedDataParallel with multiple GPUs per process.)
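
For the single-process multi-GPU DataParallel case specifically, the pattern shown in the linked amp examples note is to apply autocast inside the model's forward, so the per-replica threads each (re-)enter autocast themselves. A minimal sketch (the module and sizes are placeholders):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(64, 64)

    @torch.cuda.amp.autocast()        # autocast is entered in each replica's thread
    def forward(self, x):
        return self.fc(x)

model = nn.DataParallel(Net().to("cuda:0"))   # one process, multiple GPUs
x = torch.randn(32, 64, device="cuda:0")
out = model(x)                                 # runs under autocast on every replica
print(out.dtype)                               # torch.float16
```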