I am working in PyTorch 2.0.0.
Following the usual instructions for automatic mixed precision (AMP), training is done using:
```python
# for autocasting
scaler = torch.cuda.amp.GradScaler()
## ....
for inputs, labels in loader:
    if inputs.device != device:
        inputs, labels = inputs.to(device), labels.to(device)

    ## modify to use autocasting as described at
    ## https://pytorch.org/docs/stable/amp.html

    # Set the parameter gradients to zero
    optimizer.zero_grad()

    # Forward pass, backward pass, optimize
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        outputs = model(inputs)
        loss_output = loss(outputs, labels)
    scaler.scale(loss_output).backward()
    scaler.step(optimizer)
    scaler.update()
```
This seems to work: the model learns almost as well with AMP as without. This is true for models both with and without batch normalization.
Model evaluation on an independent validation sample works perfectly well for a model without batch normalization, but it produces NaNs when the model includes batch normalization layers. The relevant code snippet is:
```python
# switch to evaluate mode
model.eval()
with torch.no_grad():
    for inputs, labels in loader:
        if inputs.device != device:
            inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            val_outputs = model(inputs)
            loss_output = loss(val_outputs, labels)
        total_loss += loss_output.data.item()
```
The code runs for two epochs, alternating between the training and validation samples. Even though the validation losses are all NaNs, training resumes normally at the beginning of the second epoch, and learning continues to progress.
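One thing that may be worth checking, since training is fine and only evaluation breaks: in train mode a batch normalization layer normalizes with the statistics of the current batch, while in eval mode it uses its stored `running_mean` and `running_var` buffers. If those buffers ever picked up a NaN or Inf, training could keep progressing while every eval-mode forward pass returns NaN. This is only a possible explanation, not something confirmed for this model. A minimal sketch of such a check follows; the helper name `nonfinite_bn_buffers` is made up for illustration.

```python
import torch
from torch import nn

# BatchNorm layer types to inspect (an assumption; extend if the model uses others).
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def nonfinite_bn_buffers(model):
    """Return names of BatchNorm running-stat buffers that contain NaN or Inf."""
    bad = []
    for mod_name, module in model.named_modules():
        if not isinstance(module, BN_TYPES):
            continue
        for buf_name in ("running_mean", "running_var"):
            buf = getattr(module, buf_name, None)
            # buf is None when the layer was built with track_running_stats=False
            if buf is not None and not torch.isfinite(buf).all():
                bad.append(f"{mod_name}.{buf_name}")
    return bad
```

Calling `nonfinite_bn_buffers(model)` right after a training epoch and before `model.eval()` would show whether any running statistics are already corrupted at that point.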
Has anyone else observed this type of behavior?
Thanks in advance for any suggestions.
An earlier version of the model with batch normalization trained in FP32 but not with AMP. Getting training to work with AMP required “flattening” the model so that each batch normalization layer is a direct child of the model, not a child of one of its children. This approach was suggested by a solution to a similar problem in NVIDIA's apex.amp framework. While this solved the model.train() problem, it did not solve the model.eval() problem.
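To make the “flattening” concrete, here is a minimal sketch. The classes `NestedNet` and `FlatNet` and the helper `batchnorm_is_direct_child` are hypothetical stand-ins, not the actual model; they only show the difference between a batch normalization layer nested inside a sub-block and one registered directly on the model.

```python
import torch
from torch import nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


class NestedNet(nn.Module):
    """BatchNorm sits inside a sub-block: a child of a child of the model."""

    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU())
        self.head = nn.Linear(16, 2)

    def forward(self, x):
        return self.head(self.block(x))


class FlatNet(nn.Module):
    """Same layers, but each one (including BatchNorm) is a direct child."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 16)
        self.bn = nn.BatchNorm1d(16)
        self.act = nn.ReLU()
        self.head = nn.Linear(16, 2)

    def forward(self, x):
        return self.head(self.act(self.bn(self.fc(x))))


def batchnorm_is_direct_child(model):
    """True if every BatchNorm layer in `model` is one of its direct children."""
    direct = {id(m) for m in model.children() if isinstance(m, BN_TYPES)}
    every = {id(m) for m in model.modules() if isinstance(m, BN_TYPES)}
    return every == direct
```

With these definitions, `batchnorm_is_direct_child(NestedNet())` is false and `batchnorm_is_direct_child(FlatNet())` is true, which is the property the flattened model was restructured to have.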