Training with Half Precision

We’ve developed a lightweight, open-source set of PyTorch tools to enable easier, more numerically stable mixed precision training. Mixed precision means that the majority of the network uses FP16 arithmetic (reducing memory storage/bandwidth demands and enabling Tensor Cores for GEMMs and convolutions), while a small subset of operations is executed in FP32 for improved stability.
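To make the stability concern concrete, here is a small stdlib-only sketch that rounds Python floats to IEEE 754 half precision through the `struct` module's `'e'` format (an illustration chosen here, not part of Apex). It shows the three FP16 failure modes that motivate keeping some operations in FP32: limited precision, underflow of small values, and overflow above 65504.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value.

    struct's 'e' format packs a float as FP16 and raises OverflowError
    when x is too large to be represented as a finite FP16 number.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Precision: FP16 has an 11-bit significand, so above 2048 the spacing is 2.
print(to_fp16(2049.0))    # 2048.0
# Underflow: values below ~2.98e-8 (half the smallest subnormal) round to zero.
print(to_fp16(1e-8))      # 0.0
# Range: 65504.0 is the largest finite FP16 value.
print(to_fp16(65504.0))   # 65504.0
try:
    to_fp16(70000.0)
except OverflowError:
    print("70000.0 does not fit in FP16")
```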

Highlights include:

  • Amp, a tool that executes all numerically safe Torch functions in FP16 while automatically casting potentially unstable operations to FP32. Amp also automatically implements dynamic loss scaling. Amp is designed to offer maximum numerical stability and most of the speed benefits of pure FP16 training.
  • FP16_Optimizer, an optimizer wrapper that automatically implements FP32 master weights for parameter updates, as well as static or dynamic loss scaling. FP16_Optimizer is designed to be minimally invasive (it doesn’t change the execution of Torch operations) and offer almost all the speed of pure FP16 training with significantly improved numerical stability.
  • apex.parallel.DistributedDataParallel, a distributed module wrapper that achieves high performance by overlapping computation with communication during backward(). Apex DistributedDataParallel is useful for both pure FP32 and mixed precision training.
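The benefit of FP32 master weights (the FP16_Optimizer bullet above) can be illustrated with the same rounding trick: a small update applied directly to an FP16 weight can round away entirely, while an FP32 master copy accumulates it. This is a hypothetical toy with a single scalar weight and a fixed update of 1e-4 per step, not Apex's implementation.

```python
import struct


def to_fp16(x: float) -> float:
    # Round to the nearest IEEE 754 half-precision value.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    # Round to the nearest IEEE 754 single-precision value.
    return struct.unpack("<f", struct.pack("<f", x))[0]


update = 1e-4                 # hypothetical lr * grad, identical every step
w_fp16_only = to_fp16(1.0)    # weight stored and updated only in FP16
w_master = to_fp32(1.0)       # FP32 master copy of the same weight

for _ in range(100):
    # Just below 1.0 the FP16 spacing is ~4.9e-4, so a 1e-4 step rounds away.
    w_fp16_only = to_fp16(w_fp16_only - update)
    # The FP32 master copy has enough precision to keep every step.
    w_master = to_fp32(w_master - update)

w_model = to_fp16(w_master)   # FP16 copy that the forward pass would use

print(w_fp16_only)  # 1.0: every update was lost
print(w_master)     # approximately 0.99
```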

Full API documentation can be found here.


Our examples page demonstrates the use of FP16_Optimizer and Apex DistributedDataParallel. Amp examples are coming soon, and Amp’s use is thoroughly discussed in its README.

Give Apex a try and let us know what you think!

Sorry for the double post; the forum told me that “new users may only post 2 links at a time” or something along those lines.


The link to csarofeen/examples does not work anymore. You can find an example here: Fp16 on pytorch 0.4


Hi, thanks for your explanation.
May I ask why BN must use float32? Does that mean BN is different from other layers, like conv, linear, etc.?

I’d say the easiest way to use it without making mistakes is to use PyTorch Lightning with

This will train your model using 16-bit.


Thanks @mcarilli. Apex was very useful to us in our project.

Any suggestions on using float16 with transformers? Should I keep some layers in float32, just as batch normalization is recommended to be kept in float32?

I would generally recommend using the automatic mixed precision package (via torch.cuda.amp), which casts the inputs to the appropriate dtype for each operation.
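A minimal training-step sketch with `torch.cuda.amp`, assuming a toy `nn.Sequential` model, random data, and SGD (all placeholders). `autocast` picks the dtype per operation (for example, on CUDA linear layers run in FP16 while `mse_loss` is autocast to FP32), and `GradScaler` handles loss scaling. Both are disabled here when CUDA is unavailable, so the same code also runs in plain FP32 on a CPU. Newer PyTorch releases also expose this functionality as `torch.autocast` and `torch.amp.GradScaler`.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # torch.cuda.amp only affects CUDA tensors

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = nn.MSELoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def train_step(x, y):
    optimizer.zero_grad(set_to_none=True)
    # Inside autocast each op runs in the dtype autocast selects for it;
    # the model's parameters themselves stay in FP32.
    with torch.cuda.amp.autocast(enabled=use_amp):
        output = model(x)
        loss = criterion(output, y)
    # Backward on the scaled loss; step() unscales the gradients and skips
    # the update if any of them are inf/NaN; update() adjusts the scale.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


x = torch.randn(8, 16, device=device)
y = torch.randn(8, 1, device=device)
losses = [train_step(x, y) for _ in range(5)]
```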


Okay, thanks. Should we also keep val_step inside the autocast scope for a fair comparison between tr_loss and val_loss?

Yes, you can also use autocast during validation.
Especially if you plan on using it for the test dataset (or deployment), I would use it.
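A validation sketch, assuming a hypothetical `validate` helper, a toy linear model, and a plain list of `(x, y)` batches standing in for a DataLoader. It runs under `torch.no_grad()` with the same `autocast` setting as training; since there is no backward pass, no `GradScaler` is involved.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"


@torch.no_grad()  # no backward pass during validation, so no GradScaler
def validate(model, batches, criterion):
    """Average loss over (x, y) batches, with autocast enabled as in training."""
    model.eval()
    total, count = 0.0, 0
    for x, y in batches:
        with torch.cuda.amp.autocast(enabled=use_amp):
            loss = criterion(model(x), y)
        total += loss.item() * x.size(0)
        count += x.size(0)
    model.train()  # switch back to training mode
    return total / count


model = nn.Linear(16, 1).to(device)
batches = [(torch.randn(4, 16, device=device), torch.randn(4, 1, device=device))
           for _ in range(3)]
val_loss = validate(model, batches, nn.MSELoss())
```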


I used the torch.cuda.amp tools to train a U-Net-like network, but my loss function returned NaN. I guess this is an overflow problem caused by FP16. Can you give me some advice on how to overcome this? Thank you so much!

Could you check if the output of the model already contains invalid values?
If so, could you check the intermediate activations for invalid values (e.g. using torch.isfinite(out).all()) to narrow down the first occurrence?


What do you mean by “first occurrence”? I use torch.autograd.set_detect_anomaly(True), and the output says there is a NaN problem in SqrtBackward, AddBackward, or CudnnConvolutionBackward (it varies), in its 0th output.
I think this is an AMP problem, because it doesn’t happen when I turn AMP off.
Thank you!
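A small CPU-only sketch of how anomaly detection reports such a failure, using a contrived graph (`sqrt(x) * 0` at `x = 0`, chosen purely for illustration) whose backward pass produces `0 / 0 = NaN` inside the sqrt backward function. The exact function name in the message (e.g. `SqrtBackward` vs `SqrtBackward0`) depends on the PyTorch version.

```python
import torch

x = torch.zeros(3, requires_grad=True)

message = ""
with torch.autograd.detect_anomaly():
    # The sqrt backward computes grad / (2 * sqrt(x)). The upstream gradient
    # is 0 (because of "* 0") and sqrt(0) is 0, so it returns 0 / 0 = NaN.
    y = (torch.sqrt(x) * 0).sum()
    try:
        y.backward()
    except RuntimeError as err:
        message = str(err)

print(message)
```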

By “first occurrence” I meant the first activation that contains an invalid value.
Since anomaly detection currently points at different operations, finding that activation would help narrow down the offending operation (e.g. an eps value used in sqrt might be too small when using amp and could thus underflow).
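The eps underflow mentioned above can be reproduced on the CPU: an eps of 1e-8 (an arbitrary example value) is not representable in FP16 and rounds to zero, so `sqrt(x + eps)` at `x = 0` gets an infinite gradient, while the same eps kept in FP32 gives a finite one. Computing such ops in FP32 or using a larger eps are common workarounds; which one fits depends on the model.

```python
import torch

eps = 1e-8
eps_fp16 = torch.tensor(eps).half().float().item()  # eps after a round trip through FP16
print(eps_fp16)  # 0.0


def sqrt_grad_at_zero(eps_value):
    """Gradient of sqrt(x + eps_value) at x = 0, computed in FP32."""
    x = torch.zeros(1, requires_grad=True)
    torch.sqrt(x + eps_value).sum().backward()
    return x.grad.item()


print(sqrt_grad_at_zero(eps_fp16))  # inf
print(sqrt_grad_at_zero(eps))       # about 5000, i.e. 1 / (2 * sqrt(1e-8))
```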


That’s much clearer now!
Let’s assume that I have many layers that use the sqrt operation. How can I detect which one causes the overflow/underflow problem?
Is there any way to find the first occurrence without modifying the forward pass of each layer?
Thank you!

You could use forward hooks as described here, which would allow you to check the outputs without changing the forward function, as long as you are using nn.Modules.
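A sketch of that approach, assuming a hypothetical helper `find_first_nonfinite`: it registers a forward hook on every submodule, records the first one whose output contains inf/NaN, and removes the hooks afterwards. Because a child module finishes before its parent, the first recorded name is the innermost module that produced the invalid values. To reproduce amp behavior, call it inside the same `autocast` context used for training.

```python
import torch
import torch.nn as nn


def find_first_nonfinite(model, *inputs):
    """Run one forward pass and return the name of the first submodule (in
    execution order) whose output contains inf/NaN, or None if all are finite."""
    first_bad = []
    handles = []

    def make_hook(name):
        def hook(module, module_inputs, output):
            if first_bad:
                return  # the first occurrence has already been found
            outputs = output if isinstance(output, (tuple, list)) else (output,)
            for out in outputs:
                if (torch.is_tensor(out) and out.is_floating_point()
                        and not torch.isfinite(out).all()):
                    first_bad.append(name)
                    return
        return hook

    # Hook every submodule; the root module (empty name) is skipped.
    for name, module in model.named_modules():
        if name:
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:
            handle.remove()  # leave the model unchanged afterwards
    return first_bad[0] if first_bad else None


# Usage: deliberately break the last layer of a toy model.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
with torch.no_grad():
    model[2].weight.fill_(float("inf"))
print(find_first_nonfinite(model, torch.randn(3, 4)))  # 2
```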


Thank you so much. I will try and update the result!

I think I misunderstood the output log. In the first epoch, the network finishes the forward pass and returns a finite loss value (tensor(0.2221, device='cuda:0', grad_fn=<L1LossBackward>)).

This means there aren’t invalid values in any layer. I also used torch.isfinite(out).all() to check the activation outputs, and it returns tensor(True, device='cuda:0').
So the problem is in the backward pass, isn’t it? If so, can I use register_forward_hook to figure out which layer causes the NaN gradient and narrow it down? Are there any trade-offs or consequences?

This is the output log:

tensor(0.2221, device='cuda:0', grad_fn=<L1LossBackward>)
RuntimeError                              Traceback (most recent call last)
<ipython-input-18-98005322b3ff> in <module>()
     38       loss = criterion(output, target)
     39       print(loss)
---> 40     scaler.scale(loss).backward()
     41     scaler.step(optimizer)
     42     scaler.update()

1 frames
/usr/local/lib/python3.7/dist-packages/torch/autograd/ in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    147     Variable._execution_engine.run_backward(
    148         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 149         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

RuntimeError: Function 'MulBackward0' returned nan values in its 0th output.

I used register_backward_hook to hook the input/output gradients of the last layer of my network using this code snippet:

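import torch

def printgradnorm(self, grad_input, grad_output):
    # Backward hook: `self` is the module the hook is registered on.
    print('Inside ' + self.__class__.__name__ + ' backward')
    print('Inside class:' + self.__class__.__name__)
    print('grad_input: ', type(grad_input))
    print('grad_input[0]: ', type(grad_input[0]))
    print('grad_output: ', type(grad_output))
    print('grad_output[0]: ', type(grad_output[0]))
    print('grad_input size:', grad_input[0].size())
    print('grad_output size:', grad_output[0].size())
    print('grad_input norm:', grad_input[0].norm())
    print('grad_output norm:', grad_output[0].norm())
    # Check the gradients flowing through this module for inf/NaN.
    print('is finite: ', torch.isfinite(grad_input[0]).all())
    print('is finite: ', torch.isfinite(grad_output[0]).all())

# Registration on the last layer (the attribute name `final_conv` is hypothetical;
# newer PyTorch versions recommend register_full_backward_hook instead):
# model.final_conv.register_backward_hook(printgradnorm)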



Without using AMP package, the output is:

loss value:  tensor(0.3254, device='cuda:0', grad_fn=<L1LossBackward>)
Inside Conv2d backward
Inside class:Conv2d

grad_input:  <class 'tuple'>
grad_input[0]:  <class 'torch.Tensor'>
grad_output:  <class 'tuple'>
grad_output[0]:  <class 'torch.Tensor'>

grad_input size: torch.Size([2, 16, 256, 256])
grad_output size: torch.Size([2, 16, 256, 256])
grad_input norm: tensor(0.0007, device='cuda:0')
grad_output norm: tensor(0.0007, device='cuda:0')

is finite:  tensor(True, device='cuda:0')
is finite:  tensor(True, device='cuda:0')

and using AMP gives this output:

loss value:  tensor(0.3358, device='cuda:0', grad_fn=<L1LossBackward>)
Inside Conv2d backward
Inside class:Conv2d

grad_input:  <class 'tuple'>
grad_input[0]:  <class 'torch.Tensor'>
grad_output:  <class 'tuple'>
grad_output[0]:  <class 'torch.Tensor'>

grad_input size: torch.Size([2, 16, 256, 256])
grad_output size: torch.Size([2, 16, 256, 256])
grad_input norm: tensor(45.2500, device='cuda:0', dtype=torch.float16)
grad_output norm: tensor(45.2500, device='cuda:0', dtype=torch.float16)

is finite:  tensor(True, device='cuda:0')
is finite:  tensor(True, device='cuda:0')

In the next layer of the backward pass, the gradient is NaN (inf).

How can I narrow down the exploding value in this case?
Once again, thank you so much!

You have to be a bit careful about when you check for invalid gradients during mixed-precision training.
The important part is that the forward pass does not create invalid values, as these would point towards an overflow, which you should then narrow down using the aforementioned forward hooks.
However, based on your description it seems that the forward pass does not yield any invalid values, but anomaly detection triggers during the backward pass.
This is expected for the first few iterations when using the GradScaler with the default scale value.
The loss is initially scaled by init_scale=65536.0. This can overflow the gradients; scaler.step(optimizer) checks for these invalid gradients and skips the optimizer.step() call, and scaler.update() then lowers the scale value. The parameters are thus never updated with the invalid gradients.
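The skip-and-back-off behavior described above can be modeled with a few lines of plain Python. This `ToyDynamicLossScaler` is a hypothetical simplification, not PyTorch's code: it flags overflow when a scaled gradient exceeds the FP16 maximum (ignoring rounding near the limit), while the defaults mirror `GradScaler`'s documented ones (`init_scale=65536.0`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`).

```python
FP16_MAX = 65504.0  # largest finite float16 value


class ToyDynamicLossScaler:
    """Simplified, hypothetical model of dynamic loss scaling."""

    def __init__(self, init_scale=65536.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._finite_steps = 0

    def step(self, grads, apply_update):
        """grads: true (unscaled) gradients; returns True if the update was applied."""
        # Scaling the loss by `scale` scales every gradient by `scale`. In a real
        # FP16 backward pass, values beyond FP16_MAX would overflow to inf.
        scaled = [g * self.scale for g in grads]
        if any(abs(s) > FP16_MAX for s in scaled):
            # Invalid gradients: skip the update and lower the scale.
            self.scale *= self.backoff_factor
            self._finite_steps = 0
            return False
        # Finite gradients: unscale and apply, then grow the scale after
        # `growth_interval` consecutive successful steps.
        apply_update([s / self.scale for s in scaled])
        self._finite_steps += 1
        if self._finite_steps == self.growth_interval:
            self.scale *= self.growth_factor
            self._finite_steps = 0
        return True


updates = []
scaler = ToyDynamicLossScaler()
results = [scaler.step([2.0], updates.append) for _ in range(3)]
print(results)       # [False, False, True]
print(scaler.scale)  # 16384.0
```

With a gradient of 2.0, the scales 65536 and 32768 both push it past the FP16 range, so the first two steps are skipped; starting from `init_scale=1024.0` the same update would be applied on the very first step.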
If you want to avoid the initially skipped steps, you could set a lower init_scale to avoid this behavior.
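A sketch with the real `torch.cuda.amp.GradScaler`, assuming a toy linear model and random data: it passes a lower `init_scale` (2.0**10 is an arbitrary example value), calls `scaler.unscale_(optimizer)` so that `param.grad` holds the true gradients before they are inspected, and detects a skipped step by checking whether `update()` lowered the scale. That skip check is a common heuristic, not a dedicated API. Without CUDA the scaler is disabled here, so `get_scale()` returns 1.0 and no step is reported as skipped.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# A lower init_scale than the default 65536.0 means fewer skipped steps at the start.
scaler = torch.cuda.amp.GradScaler(init_scale=2.0**10, enabled=use_amp)


def train_step(x, y):
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = nn.functional.l1_loss(model(x), y)
    scaler.scale(loss).backward()
    # Unscale in place so the gradients can be inspected at their true magnitude.
    scaler.unscale_(optimizer)
    grads_finite = all(torch.isfinite(p.grad).all() for p in model.parameters()
                       if p.grad is not None)
    scale_before = scaler.get_scale()
    scaler.step(optimizer)  # skips optimizer.step() if inf/NaN gradients were found
    scaler.update()         # lowers the scale after a skipped step
    skipped = scaler.get_scale() < scale_before
    return loss.item(), grads_finite, skipped


x = torch.randn(4, 8, device=device)
y = torch.randn(4, 1, device=device)
loss_value, grads_finite, skipped = train_step(x, y)
```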

With that being said, in your first post you’ve mentioned that the “loss function gave NaN” values, which points towards the forward pass.


Why would float16 cause convergence issues for BN?