Training with Half Precision

It works, but you want to make sure the BatchNorm layers keep their parameters and running statistics in float32, since accumulating those statistics in half precision leads to convergence issues. You can do that with something like:

import torch.nn as nn

model.half()  # convert model parameters and buffers to half precision
for layer in model.modules():
    if isinstance(layer, nn.BatchNorm2d):
        layer.float()  # keep BatchNorm parameters and running stats in float32

Then make sure your input is in half precision.
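
A minimal sketch of what a training step could then look like, assuming the model was converted as above and that loader, criterion, and optimizer are placeholder names for your own objects:

for input, target in loader:
    input = input.cuda().half()   # inputs must match the half-precision weights
    target = target.cuda()
    output = model(input)
    loss = criterion(output, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()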

Christian Sarofeen from NVIDIA ported the ImageNet training example to use FP16 here:

We’d like to clean up the FP16 support to make it more accessible, but the above should be enough to get you started.
