Autocast with batch normalization in PyTorch model.eval() returns NaNs

I am working with PyTorch 2.0.0.

Following the usual instructions for automatic mixed precision (AMP), training is done with:

    # for autocasting
    scaler = torch.cuda.amp.GradScaler()
    # ...

    for inputs, labels in loader:
        if inputs.device != device:
            inputs, labels = inputs.to(device), labels.to(device)

        # modified to use autocasting as described at
        # https://pytorch.org/docs/stable/amp.html

        # Set the parameter gradients to zero
        optimizer.zero_grad()

        # Forward pass, backward pass, optimize
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            outputs = model(inputs)
            loss_output = loss(outputs, labels)
        scaler.scale(loss_output).backward()
        scaler.step(optimizer)
        scaler.update()

This seems to work: the model learns almost as well with AMP as without, both for models with and without batch normalization.

Model evaluation on an independent validation sample works perfectly well for a model without batch normalization, but it produces NaNs when the model includes batch normalization layers. The relevant code snippet is

    # switch to evaluate mode
    model.eval()

    total_loss = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            if inputs.device != device:
                inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                val_outputs = model(inputs)
                loss_output = loss(val_outputs, labels)

            total_loss += loss_output.item()

The code runs for two epochs, alternating between the training and validation samples. Even though the validation losses are all NaNs, training resumes perfectly well at the start of the second epoch, and the learning continues to progress.
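For reference, a check along these lines (plain torch.isfinite calls; variable names match the snippet above) would show whether the NaNs already appear in the model output rather than only in the loss:

    # Inside the validation loop, right after the autocast block:
    if not torch.isfinite(val_outputs).all():
        print("NaN/Inf already present in the model output")
    if not torch.isfinite(loss_output):
        print("NaN/Inf in the loss value")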

Has anyone else observed this type of behavior?

Thanks in advance for any suggestions.

Mike

An earlier version of the model with batch normalization worked in FP32 but not in AMP. Getting the training to work in AMP required "flattening" the model so that each batch normalization layer is a direct child of the model rather than a child of a submodule. This approach was suggested by a solution to a similar problem in NVIDIA's apex.amp framework. While this solved the model.train() problem, it did not solve the model.eval() problem.
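To illustrate what I mean by "flattening" (toy layers here, not my actual architecture):

    import torch.nn as nn

    # Before: batch norm buried inside a sub-block, so it is a grandchild of the model
    class Block(nn.Module):
        def __init__(self, channels):
            super().__init__()
            self.conv = nn.Conv2d(channels, channels, 3, padding=1)
            self.bn = nn.BatchNorm2d(channels)

        def forward(self, x):
            return nn.functional.relu(self.bn(self.conv(x)))

    class NestedModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.block = Block(16)   # self.block.bn is not a direct child

        def forward(self, x):
            return self.block(x)

    # After: each batch norm registered directly on the model, i.e. a direct child
    class FlatModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(16, 16, 3, padding=1)
            self.bn = nn.BatchNorm2d(16)

        def forward(self, x):
            return nn.functional.relu(self.bn(self.conv(x)))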

This would indicate that a forward pass during training overflowed and thus updated the running stats of the batchnorm layers with invalid values. Note that overflows in the forward pass are not expected in mixed-precision training and usually indicate a broken training setup.
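One way to check this is to inspect the running statistics of every batchnorm layer after a training epoch; if an overflow corrupted them, some entries will be inf or NaN. A minimal sketch using standard module introspection (adapt to your model):

    import torch
    import torch.nn as nn

    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    for name, module in model.named_modules():
        if isinstance(module, bn_types):
            mean_ok = torch.isfinite(module.running_mean).all().item()
            var_ok = torch.isfinite(module.running_var).all().item()
            if not (mean_ok and var_ok):
                print(f"{name}: running_mean finite={mean_ok}, running_var finite={var_ok}")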