It works, but you want to make sure that the BatchNorm layers use float32 for accumulation, or you will run into convergence issues. You can do that with something like:
import torch.nn as nn

model.half()  # convert all parameters and buffers to half precision
for layer in model.modules():
    if isinstance(layer, nn.BatchNorm2d):
        layer.float()  # keep BatchNorm accumulation in float32
Then make sure your inputs are in half precision as well.
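As a rough sketch, the input conversion in the training loop could look like this (the model, loader, criterion, and optimizer names are just placeholders, not from an actual example):

# assumes `model` has already been converted with model.half() as above,
# and that `loader`, `criterion`, and `optimizer` are defined elsewhere
model.cuda()
for inputs, targets in loader:
    inputs = inputs.cuda().half()  # inputs must match the model's precision
    targets = targets.cuda()
    output = model(inputs)
    loss = criterion(output, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()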
Christian Sarofeen from NVIDIA ported the ImageNet training example to use FP16 here:
We’d like to clean up the FP16 support to make it more accessible, but the above should be enough to get you started.