NaN Loss Issues with Precision 16 in PyTorch Lightning GAN Training

Hello,
I am experiencing issues training with precision 16 in PyTorch Lightning.
The generator (netG_A) and the discriminator (netD_A) each have their own optimizer, and a third optimizer trains netF_A, the network used by the PatchNCE loss.

I confirmed in the documentation that manual optimization (calling manual_backward) is required when using multiple optimizers, and the code runs without issues with precision 32.

However, when I set precision=16 in the trainer to apply precision 16, I encounter a problem where the loss values become NaN. This NaN issue always occurs in the GAN loss part. To address this, I have tried the following:

  1. Applying float() to the data and the resulting loss values when calculating the GAN loss.
  2. Manually implementing auto casting and gradient scaling in the training_step.
  3. Clipping gradients.

Despite trying all these methods, I couldn’t avoid the NaN values in the loss.
Is there anything I might be missing in this situation?

Thank you very much for reading the long question!

class ProposedSynthesisModule(BaseModule_AtoB):
    """A -> B synthesis GAN trained with Lightning manual optimization.

    ``configure_optimizers`` creates three optimizers, in this order:
    one for the generator ``netG_A``, one for the discriminator ``netD_A``,
    and one for ``netF_A``, the network used by the PatchNCE loss. The
    generator objective (``backward_G``) is the sum of an adversarial term,
    a contextual (style) term and a PatchNCE term.
    """

    def __init__(
        self,
        netG_A: torch.nn.Module,  # Generator
        netD_A: torch.nn.Module,  # Discriminator
        netF_A: torch.nn.Module,  # For PatchNCELoss
        optimizer,
        params,
        scheduler,
        *args,
        **kwargs: Any
    ):
        super().__init__(params, *args, **kwargs)

        # assign networks
        self.netG_A = netG_A
        self.netD_A = netD_A
        self.netF_A = netF_A

        # Makes the init params accessible via ``self.hparams`` and stores
        # them in checkpoints.
        self.save_hyperparameters(logger=False)
        # Several optimizers are stepped by hand in training_step.
        self.automatic_optimization = False
        self.optimizer = optimizer
        self.params = params
        self.scheduler = scheduler

        # Feature layers passed to the contextual loss (values presumably
        # per-layer weights -- TODO confirm against Contextual_Loss).
        style_feat_layers = {
            "conv_2_2": 1.0,
            "conv_3_2": 1.0,
            "conv_4_2": 1.0,
            "conv_4_4": 1.0,
        }

        # loss functions
        self.criterionContextual = Contextual_Loss(style_feat_layers)
        self.criterionGAN = GANLoss(gan_type='lsgan')
        self.criterionNCE = PatchNCELoss(False, nce_T=0.07, batch_size=params.batch_size)

        # PatchNCE specific initializations
        self.nce_layers = [0, 2, 4, 6]  # range: 0~6
        self.flip_equivariance = params.flip_equivariance

    @staticmethod
    def _check_finite(value, name):
        """Raise ``FloatingPointError`` if ``value`` contains NaN or +/-Inf.

        Used instead of ``assert`` so the check is not stripped by
        ``python -O``. Inf is included because float16 overflow usually
        shows up as Inf before it turns into NaN.
        """
        if not torch.isfinite(torch.as_tensor(value)).all():
            raise FloatingPointError(f"{name} is NaN or Inf")

    @staticmethod
    def _freeze(module):
        """Set ``requires_grad=False`` on ``module``'s parameters; return the previous flags."""
        params = list(module.parameters())
        previous = [p.requires_grad for p in params]
        for p in params:
            p.requires_grad_(False)
        return previous

    @staticmethod
    def _restore_requires_grad(module, previous):
        """Restore the ``requires_grad`` flags returned by ``_freeze``."""
        for p, flag in zip(module.parameters(), previous):
            p.requires_grad_(flag)

    def backward_G(self, real_a, real_b, fake_b, lambda_style, lambda_nce):
        """Compute the generator objective (no backward pass is run here).

        Args:
            real_a: source-domain batch.
            real_b: target-domain batch.
            fake_b: generated batch (presumably ``netG_A`` applied to ``real_a``
                in ``model_step``). It must still be attached to the generator's
                graph so every term can update ``netG_A``.
            lambda_style: weight of the contextual loss.
            lambda_nce: weight of the PatchNCE loss.

        Returns:
            Tuple ``(loss_G, loss_gan, loss_style, loss_nce)``.

        Raises:
            FloatingPointError: if any loss term is NaN or Inf.
        """
        # Adversarial term. fake_b must NOT be detached: detaching it cut the
        # only path from this loss to netG_A, so the generator got no
        # adversarial gradient at all (netD_A is frozen during this step).
        pred_fake = self.netD_A(fake_b)
        loss_gan = self.criterionGAN(pred_fake, True)
        self._check_finite(loss_gan, "GAN Loss")

        ## Contextual loss
        loss_style = self.criterionContextual(real_b, fake_b) * lambda_style
        self._check_finite(loss_style, "Contextual Loss")

        # PatchNCE loss (real_a, fake_b)
        num_patches = 256  # second argument to netF_A; presumably patches sampled per layer
        n_layers = len(self.nce_layers)
        feat_b = self.netG_A(fake_b, real_a, self.nce_layers, encode_only=True)

        flipped_for_equivariance = np.random.random() < 0.5
        if self.flip_equivariance and flipped_for_equivariance:
            # NOTE(review): only the features are flipped here, not the input
            # images. In the reference CUT implementation the inputs are flipped
            # and the query features are flipped back. Confirm this is intended.
            feat_b = [torch.flip(fb, [3]) for fb in feat_b]

        feat_a = self.netG_A(real_a, real_b, self.nce_layers, encode_only=True)
        feat_a_pool, sample_ids = self.netF_A(feat_a, num_patches, None)
        feat_b_pool, _ = self.netF_A(feat_b, num_patches, sample_ids)

        # The query is the generated-image feature and the key is the
        # source-image feature (same argument order as before).
        total_nce_loss = 0.0
        for f_query, f_key in zip(feat_b_pool, feat_a_pool):
            total_nce_loss += (self.criterionNCE(f_query, f_key) * lambda_nce).mean()
        loss_nce = total_nce_loss / n_layers
        self._check_finite(loss_nce, "NCE Loss")

        loss_G = loss_gan + loss_style + loss_nce
        self._check_finite(loss_G, "Total Loss")

        return loss_G, loss_gan, loss_style, loss_nce

    def training_step(self, batch: Any, batch_idx: int):
        """Run one generator/PatchNCE update followed by one discriminator update."""
        optimizer_G_A, optimizer_D_A, optimizer_F_A = self.optimizers()
        real_a, real_b, fake_b = self.model_step(batch)

        # Generator + netF_A step. ``optimizer_G_A.toggle_model()`` would set
        # requires_grad=False on every parameter not owned by optimizer_G_A,
        # including netF_A, so optimizer_F_A.step() never saw a gradient.
        # Freeze only the discriminator instead.
        d_flags = self._freeze(self.netD_A)
        try:
            loss_G, loss_gan, loss_style, loss_nce = self.backward_G(
                real_a, real_b, fake_b, self.params.lambda_style, self.params.lambda_nce
            )
            self.manual_backward(loss_G)
            # NOTE(review): with precision=16, confirm these norms are computed on
            # unscaled gradients. In manual optimization Lightning may only unscale
            # inside optimizer.step().
            for opt in (optimizer_G_A, optimizer_F_A):
                self.clip_gradients(opt, gradient_clip_val=0.5, gradient_clip_algorithm="norm")
            optimizer_G_A.step()
            optimizer_F_A.step()
            optimizer_G_A.zero_grad()
            optimizer_F_A.zero_grad()
        finally:
            self._restore_requires_grad(self.netD_A, d_flags)

        self.log("G_loss", loss_G.detach(), prog_bar=True)
        self.log("loss_gan", loss_gan.detach(), prog_bar=True)
        self.log("loss_style", loss_style.detach(), prog_bar=True)
        self.log("loss_nce", loss_nce.detach(), prog_bar=True)

        # Discriminator step. toggle_model freezes netG_A and netF_A here.
        with optimizer_D_A.toggle_model():
            loss_D_A = self.backward_D_A(real_b, fake_b)
            self.manual_backward(loss_D_A)
            self.clip_gradients(
                optimizer_D_A, gradient_clip_val=0.5, gradient_clip_algorithm="norm"
            )
            optimizer_D_A.step()
            optimizer_D_A.zero_grad()
        self.log("Disc_A_Loss", loss_D_A.detach(), prog_bar=True)

    def configure_optimizers(self):
        """Create one optimizer (and optionally one scheduler) per network.

        The order is G, D, F; ``training_step`` unpacks ``self.optimizers()``
        in exactly this order.

        NOTE(review): if training in pure float16, Adam's default eps=1e-8
        underflows to 0 in fp16; consider eps>=1e-7 in the optimizer factory.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
        """
        optimizers = [
            self.hparams.optimizer(params=net.parameters())
            for net in (self.netG_A, self.netD_A, self.netF_A)
        ]
        schedulers = []
        if self.hparams.scheduler is not None:
            schedulers = [self.hparams.scheduler(optimizer=opt) for opt in optimizers]

        return optimizers, schedulers
class GANLoss(nn.Module):
    """Adversarial loss wrapper supporting several GAN formulations.

    Args:
        gan_type: one of 'vanilla', 'lsgan', 'wgan', 'wgan_softplus',
            'hinge', 'swd'; anything else raises NotImplementedError.
        real_label_val: value stored for "real" targets.
        fake_label_val: value stored for "fake" targets.
        loss_weight: multiplier applied to generator-side losses only.
        reduction: reduction forwarded to the BCE / MSE criteria.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0, reduction='mean'):
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # Lazy factories: only the selected criterion is built or looked up.
        builders = {
            'vanilla': lambda: nn.BCEWithLogitsLoss(reduction=reduction),
            'lsgan': lambda: nn.MSELoss(reduction=reduction),
            'wgan': lambda: self._wgan_loss,
            'wgan_softplus': lambda: self._wgan_softplus_loss,
            'hinge': nn.ReLU,
            'swd': lambda: self._slicedWassersteinDistance_loss,
        }
        build = builders.get(self.gan_type)
        if build is None:
            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
        self.loss = build()

    def forward(self, input, target_is_real, is_disc=False):
        """Return the adversarial loss for ``input``.

        Discriminator losses are returned as-is; generator losses are
        multiplied by ``loss_weight``.
        """
        target_label = self.get_target_label(input, target_is_real)

        if self.gan_type == 'swd' and not is_disc:
            return self.loss(input, target_label) * self.loss_weight

        if self.gan_type != 'hinge':
            loss = self.loss(input, target_label)
        elif not is_disc:
            # generator side of hinge-gan
            loss = -input.mean()
        else:
            # discriminator side of hinge-gan: relu(1 - x) for real, relu(1 + x) for fake
            signed = -input if target_is_real else input
            loss = self.loss(1 + signed).mean()

        if is_disc:
            return loss
        return loss * self.loss_weight

Hey there,
It seems like you’re encountering NaN loss issues when applying Precision 16 in PyTorch Lightning, especially in the GAN loss part of your training. Despite your attempts at various solutions like applying float(), manually implementing auto casting and gradient scaling, and clipping gradients, you’re still facing NaN loss values.

This issue could stem from a few different sources, but it’s often related to numerical instability caused by the nature of floating-point arithmetic and the specific operations involved in your model and loss functions.

Here are some suggestions to address this problem:

  1. Scaling of Losses: Since NaN losses often occur due to extremely large or small values, consider scaling your loss terms appropriately. You might try scaling down your loss terms by a factor to prevent numerical instability.

  2. Gradient Clipping: Although you’ve already applied gradient clipping, try adjusting the clipping threshold or algorithm to ensure that gradients are not exploding.

  3. Optimizer Configuration: Verify that the optimizer configurations are appropriate for mixed precision training. Ensure that the optimizer’s internal calculations are compatible with float16.

  4. Loss Function: Double-check your loss function implementations, especially for the GAN loss. NaN losses can sometimes stem from improper loss calculations or divergent loss functions.

  5. Data Preprocessing: Ensure that your input data is properly preprocessed, and there are no NaN or infinite values in your dataset.

  6. Debugging NaN Propagation: You might want to debug and trace the origin of NaN values. Temporarily remove parts of your code or simplify your model to identify where NaN values are first introduced.

  7. Model Initialization: NaN issues can also arise from improper initialization of model weights. Check if your model parameters are initialized correctly.

  8. Community Support: Consider reaching out to the PyTorch Lightning community or forums for additional insights. Others might have encountered similar issues and could offer valuable advice or solutions.

Given the complexity of GAN training and mixed precision techniques, debugging NaN issues can be challenging. However, thorough testing and systematic debugging steps should help identify and resolve the underlying problem.

Thanks for the advice, but I am confused because it reads almost exactly like a ChatGPT answer.
This did not work as a solution.

And indeed @sally2’s answer sounds plausible, but is confusing.

For example, the suggestion to adjust the loss/gradient scaling

doesn’t make sense, since the GradScaler automatically lowers its scaling factor when it detects overflows during the gradient computation. Since you are describing an issue where the loss itself is already NaN, changing the scaling factor of the gradient scaling won’t change anything.

Also, no idea what this means.

@sally2 please spend at least a little time reviewing chatbot answers before posting them, as they typically only create noise, and users are certainly able to use ChatGPT themselves.

@Danny_Kim you could try to isolate the operation causing the NaN values, e.g. via this small utility:

import itertools
import random
import warnings

import torch
import torchvision.models as models
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten, tree_map

# adapted from https://github.com/albanD/subclass_zoo/blob/main/logging_mode.py

class Lit:
    """Value whose ``repr`` is the wrapped text itself, printed verbatim.

    Embedding a ``Lit`` in a container makes the container's ``repr`` show
    the pre-formatted summary without surrounding quotes.
    """

    def __init__(self, s):
        # text returned unchanged by __repr__
        self.s = s

    def __repr__(self):
        text = self.s
        return text


def fmt(t: object, print_stats=False) -> object:
    """Summarize a tensor for debug messages; return non-tensors unchanged.

    Args:
        t: any value; only ``torch.Tensor`` instances are summarized.
        print_stats: if True, append min/max/mean of the tensor.

    Returns:
        A ``Lit`` whose repr is the summary when ``t`` is a tensor,
        otherwise ``t`` itself.
    """
    if not isinstance(t, torch.Tensor):
        return t

    s = f"torch.tensor(..., size={tuple(t.shape)}, dtype={t.dtype}, device='{t.device}')"
    if print_stats:
        if t.numel() == 0:
            # min()/max() raise on empty tensors
            s += " [empty, no stats]"
        elif t.is_meta or t.is_complex() or t.layout != torch.strided:
            s += " [stats unavailable for this tensor]"
        else:
            # mean() rejects integer/bool dtypes (e.g. index tensors passed to
            # the failing op), so compute the stats in float64 for those.
            st = t if t.is_floating_point() else t.double()
            s += f" [with stats min={st.min()}, max={st.max()}, mean={st.mean()}]"
    return Lit(s)


class NaNErrorMode(TorchDispatchMode):
    """Dispatch mode that reports ATen ops whose outputs contain NaN or Inf.

    Every op run inside ``with NaNErrorMode(): ...`` executes normally.
    If any tensor it returns, including tensors inside a tuple/list of
    outputs, is non-finite, a message names the op and summarizes its
    arguments via ``fmt``. The message is emitted as a warning, or raised
    as ``RuntimeError`` when ``raise_error=True``.

    Args:
        enabled: switch the check on/off without removing the ``with`` block.
        raise_error: raise ``RuntimeError`` instead of calling ``warnings.warn``.
        print_stats: include min/max/mean of each tensor argument.
        print_nan_index: include the indices of the non-finite values.
    """

    def __init__(
        self, enabled=True, raise_error=False, print_stats=True, print_nan_index=False
    ):
        # let TorchDispatchMode set up its own bookkeeping
        super().__init__()
        self.enabled = enabled
        # warning or error
        self.raise_error = raise_error
        # print min/max/mean stats
        self.print_stats = print_stats
        # print indices of invalid values in output
        self.print_nan_index = print_nan_index

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        out = func(*args, **kwargs)
        if not self.enabled:
            return out

        # Some ops return a tuple/list of tensors (e.g. max.dim, batch norm);
        # check every tensor leaf instead of only a bare Tensor result.
        leaves = tree_flatten(out)[0]
        bad = [
            (idx, leaf)
            for idx, leaf in enumerate(leaves)
            if isinstance(leaf, torch.Tensor)
            and not leaf.is_meta
            and not torch.isfinite(leaf).all()
        ]
        if not bad:
            return out

        def fmt_arg(x):
            return fmt(x, self.print_stats)

        fmt_args = ", ".join(
            itertools.chain(
                (repr(tree_map(fmt_arg, a)) for a in args),
                (f"{k}={tree_map(fmt_arg, v)}" for k, v in kwargs.items()),
            )
        )
        msg = f"NaN outputs in out = {func}({fmt_args})"
        if self.print_nan_index:
            single = isinstance(out, torch.Tensor)
            for idx, leaf in bad:
                where = "" if single else f" (output {idx})"
                msg += f"\nInvalid values detected{where} at:\n{(~leaf.isfinite()).nonzero()}"
        if self.raise_error:
            raise RuntimeError(msg)
        warnings.warn(msg)
        return out

## warning example
# Fall back to CPU so the demo also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18().to(device)

for i in range(1000):
    # On ~5% of iterations, set one weight to NaN so the mode has something
    # to report.
    if torch.rand(1) < 0.05:
        name, param = random.choice(list(model.named_parameters()))

        print(f"setting first weight value of {name} to NaN")
        with torch.no_grad():
            param.view(-1)[0].fill_(float("nan"))

    with NaNErrorMode(
        enabled=True, raise_error=False, print_stats=True, print_nan_index=False
    ):
        out = model(torch.randn(1, 3, 224, 224).to(device))
    print(f"iter: {i}, out.sum: {out.sum()}")
    # stop once the corruption has reached the model output
    if not torch.isfinite(out).all():
        break

# Example output
# iter: 0, out.sum: 30.91858673095703
# iter: 1, out.sum: 30.042282104492188
# iter: 2, out.sum: 28.941993713378906
# iter: 3, out.sum: 30.45046043395996
# iter: 4, out.sum: 30.246807098388672
# iter: 5, out.sum: 28.838367462158203
# iter: 6, out.sum: 28.77951431274414
# iter: 7, out.sum: 28.38872718811035
# iter: 8, out.sum: 28.200435638427734
# iter: 9, out.sum: 28.041297912597656
# iter: 10, out.sum: 28.185876846313477
# iter: 11, out.sum: 30.179176330566406
# iter: 12, out.sum: 28.98915672302246
# iter: 13, out.sum: 28.568819046020508
# iter: 14, out.sum: 29.45282745361328
# iter: 15, out.sum: 28.49593734741211
# iter: 16, out.sum: 30.590713500976562
# iter: 17, out.sum: 28.74820899963379
# iter: 18, out.sum: 28.62137794494629
# iter: 19, out.sum: 30.686201095581055
# iter: 20, out.sum: 29.738740921020508
# setting first weight value of layer3.1.conv1.weight to NaN
# iter: 21, out.sum: nan
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.convolution.default(torch.tensor(..., size=(1, 256, 14, 14), dtype=torch.float32, device='cuda:0') [with stats min=0.0, max=5.781907558441162, mean=0.5651370286941528], torch.tensor(..., size=(256, 256, 3, 3), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan], None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
#   warnings.warn(msg)
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.relu_.default(torch.tensor(..., size=(1, 256, 14, 14), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan])
#   warnings.warn(msg)
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.detach.default(torch.tensor(..., size=(1, 256, 14, 14), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan])
#   warnings.warn(msg)
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.convolution.default(torch.tensor(..., size=(1, 256, 14, 14), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan], torch.tensor(..., size=(256, 256, 3, 3), dtype=torch.float32, device='cuda:0') [with stats min=-0.14507625997066498, max=0.1386573612689972, mean=4.152135443291627e-05], None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
#   warnings.warn(msg)
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.add_.Tensor(torch.tensor(..., size=(1, 256, 14, 14), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan], torch.tensor(..., size=(1, 256, 14, 14), dtype=torch.float32, device='cuda:0') [with stats min=0.0, max=5.781907558441162, mean=0.5651370286941528])
#   warnings.warn(msg)
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.convolution.default(torch.tensor(..., size=(1, 256, 14, 14), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan], torch.tensor(..., size=(512, 256, 3, 3), dtype=torch.float32, device='cuda:0') [with stats min=-0.09479710459709167, max=0.10371077060699463, mean=2.4601764380349778e-05], None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1)
#   warnings.warn(msg)
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.relu_.default(torch.tensor(..., size=(1, 512, 7, 7), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan])
#   warnings.warn(msg)
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.detach.default(torch.tensor(..., size=(1, 512, 7, 7), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan])
#   warnings.warn(msg)
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.convolution.default(torch.tensor(..., size=(1, 512, 7, 7), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan], torch.tensor(..., size=(512, 512, 3, 3), dtype=torch.float32, device='cuda:0') [with stats min=-0.11115863174200058, max=0.10481023788452148, mean=-5.145858267496806e-06], None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
#   warnings.warn(msg)
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.convolution.default(torch.tensor(..., size=(1, 256, 14, 14), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan], torch.tensor(..., size=(512, 256, 1, 1), dtype=torch.float32, device='cuda:0') [with stats min=-0.2665204107761383, max=0.27854788303375244, mean=-0.0002086192835122347], None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1)
#   warnings.warn(msg)
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.add_.Tensor(torch.tensor(..., size=(1, 512, 7, 7), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan], torch.tensor(..., size=(1, 512, 7, 7), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan])
#   warnings.warn(msg)
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.convolution.default(torch.tensor(..., size=(1, 512, 7, 7), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan], torch.tensor(..., size=(512, 512, 3, 3), dtype=torch.float32, device='cuda:0') [with stats min=-0.10145102441310883, max=0.10478692501783371, mean=-9.899334145302419e-06], None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
#   warnings.warn(msg)
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.convolution.default(torch.tensor(..., size=(1, 512, 7, 7), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan], torch.tensor(..., size=(512, 512, 3, 3), dtype=torch.float32, device='cuda:0') [with stats min=-0.11157014220952988, max=0.10770585387945175, mean=4.783814802067354e-06], None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
#   warnings.warn(msg)
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.mean.dim(torch.tensor(..., size=(1, 512, 7, 7), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan], [-1, -2], True)
#   warnings.warn(msg)
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.view.default(torch.tensor(..., size=(1, 512, 1, 1), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan], [1, 512])
#   warnings.warn(msg)
# /tmp/ipykernel_954439/4225324741.py:73: UserWarning: NaN outputs in out = aten.addmm.default(torch.tensor(..., size=(1000,), dtype=torch.float32, device='cuda:0') [with stats min=-0.043944377452135086, max=0.04383780434727669, mean=-0.00139328942168504], torch.tensor(..., size=(1, 512), dtype=torch.float32, device='cuda:0') [with stats min=nan, max=nan, mean=nan], torch.tensor(..., size=(512, 1000), dtype=torch.float32, device='cuda:0') [with stats min=-0.04419388994574547, max=0.044194161891937256, mean=7.266089232871309e-05])

I’ve used @albanD’s example to add a debug mode to check for NaN outputs.

Alternatively, you could also try to use bfloat16 as the range would be larger and overflows could be avoided (assuming an overflow is causing the issue).

1 Like

@ptrblck
Ultimately, updating all libraries and switching to bf16 resolved the issue.
Precision 16 (fp16 mixed precision) still results in NaN errors, but bf16 works as an alternative.
It was helpful to instantly check for NaN outputs using NaNErrorMode.
Thank you for your detailed and excellent response.

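I had the same problem a while back. In my case, the cause was Adam(W): its default epsilon of 1e-8 is smaller than the smallest positive float16 value (about 6e-8), so it underflows to 0 and produces NaN values. You can avoid this by using 1e-7 (eps = 1e-7 if mixed_precision else 1e-8). However, AMP caused a performance drop for me, so I ended up not using it.

Explanation: in fp16, eps=1e-8 becomes 0, so the denominator sqrt(v) + eps can be 0; when the numerator is also 0, the update is 0/0 = NaN. (This only applies when the optimizer state itself is stored in float16, e.g. pure fp16 training; with autocast-style mixed precision the parameters and optimizer state stay in float32.)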