NaN values appear intermittently while training the discriminator

I’ve been integrating the discriminator code from Projected GAN into my own project, but I’m running into ‘nan’ values that sometimes appear during training. (link: projected-gan/pg_modules/discriminator.py at main · autonomousvision/projected-gan · GitHub)

The code around the point where the nan values occur looks like this:

from models.pg_modules.discriminator import ProjectedDiscriminator
self._discriminator = ProjectedDiscriminator(interp224=False, diffaug=args.diff_aug, backbone_kwargs={'im_res': args.train_p_size}).to(self.device, non_blocking=True)
self._discriminator.feature_network.requires_grad_(False)  # the pretrained feature network stays frozen
from torch.cuda.amp import autocast
torch.autograd.set_detect_anomaly(True)
...
with autocast():
    if self.ab_iter > ADV_LOSS_START_ITER:
        # gan loss: freeze the discriminator parameters for this update
        for param in self._discriminator.parameters():
            param.requires_grad = False

        d_pred_fake = torch.zeros(1,).to(self.device)
        d_pred_real = torch.zeros(1,).to(self.device)
        p_loss = torch.zeros(1,).to(self.device)
        
        # keep only the frames whose index is not a multiple of 2 ** self.args.exp
        recon_sample_mask = [False if (z % (2 ** self.args.exp) == 0) else True for z in range(I_r.shape[2])]

        # move dim 2 to the front, select the masked slices, and flatten them into a single batch
        I_r_recon_sample = torch.flatten(I_r.permute(2,0,1,3,4)[recon_sample_mask],0,1)
        I_gt_recon_sample = torch.flatten(I_gt.permute(2,0,1,3,4)[recon_sample_mask],0,1)

        # discriminator scores for the reconstructed and ground-truth samples
        d_pred_fake = self._discriminator(I_r_recon_sample, None)
        d_pred_real = self._discriminator(I_gt_recon_sample, None)

        p_loss = torch.mean(self.perceptual_loss(I_r_recon_sample, I_gt_recon_sample))
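
For reference, the symptom can be caught right after the two discriminator calls with a simple isnan check (a minimal sketch; the tensor names and self.ab_iter come from the snippet above):

# sketch: flag the failure as soon as it happens
if torch.isnan(d_pred_fake).any() or torch.isnan(d_pred_real).any():
    print(f"nan detected at iteration {self.ab_iter}")
    breakpoint()  # drop into pdb to inspect the inputs and the state_dict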

When I trace it in pdb, the nan values fill the entire ‘d_pred_fake’ and ‘d_pred_real’ tensors.

So I saved self._discriminator.state_dict(), checked the entries one by one, and observed nan values in the following entries (“some” means part of the tensor is nan, “All” means the entire tensor is nan):
discriminator.mini_discs.0.main.0.main.1.running_mean: some
discriminator.mini_discs.0.main.0.main.1.running_var: some
discriminator.mini_discs.0.main.1.main.1.running_mean: All
discriminator.mini_discs.0.main.1.main.1.running_var: All
discriminator.mini_discs.0.main.2.main.1.running_mean: All
discriminator.mini_discs.0.main.2.main.1.running_var: All
discriminator.mini_discs.0.main.3.main.1.running_mean: All
discriminator.mini_discs.0.main.3.main.1.running_var: All

Here’s an example:
('discriminator.mini_discs.0.main.0.main.1.running_mean', tensor([ 359.3079, -298.1175, -259.4460, 251.9945, nan, -239.5377, 369.6451, -256.0357, -167.9374, -180.5169, -218.9411, 337.1237, 360.0313, 359.8208, 344.4244, 344.2194, -291.9659, 340.6350, -223.7873, 365.7599, 340.1869, 331.7911, 349.3535, 362.2325, -167.6774, 368.3821, 360.9123, 347.4428, 360.9276, 348.5717, 361.1050, 333.7169, 358.0670, -201.1755, 246.7143, 350.7696, 351.4091, nan, 298.4184, -282.5826, 348.0694, -259.4500, -276.4198, -115.4993, -161.2602, -277.1658, 333.3922, -253.5999, 349.4794, -207.5786, -154.3773, 365.1758, 362.2701, 361.7145, 280.3458, -196.3884, 340.4942, nan, -242.9723, -235.8712, 348.7751, -159.7664, 359.7401, -206.8569], device='cuda:0')),
('discriminator.mini_discs.0.main.0.main.1.running_var', tensor([29096068.0000, 21208318.0000, 14500196.0000, 10362571.0000, nan, 14853094.0000, 37487592.0000, 16709822.0000, 6030352.0000, 9105914.0000, 9564074.0000, 31026200.0000, 29738310.0000, 35907732.0000, 27829708.0000, 29373250.0000, 19012818.0000, 33209136.0000, 13472100.0000, 30584564.0000, 34549508.0000, 26184250.0000, 28431476.0000, 34256568.0000, 5768366.0000, 27793396.0000, 35548128.0000, 33561796.0000, 35050416.0000, 33522100.0000, 33429276.0000, 29630088.0000, 38192136.0000, 11925581.0000, 9155247.0000, 32848876.0000, 35051712.0000, nan, 22925550.0000, 14195619.0000, 30338050.0000, 12329162.0000, 16421357.0000, 5465982.5000, 8515777.0000, 15404568.0000, 20426380.0000, 13397992.0000, 30361580.0000, 12025232.0000, 7802527.0000, 26933310.0000, 30810768.0000, 37473368.0000, 14592150.0000, 9513152.0000, 33451412.0000, nan, 15610613.0000, 12374429.0000, 28417884.0000, 5863875.0000, 36724032.0000, 11243141.0000], device='cuda:0')),
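
The "checked them one by one" step above amounts to something like this (a minimal sketch of scanning the state_dict for nan entries):

# sketch: scan every floating-point entry of the state_dict for nan values
for name, tensor in self._discriminator.state_dict().items():
    if not tensor.is_floating_point():
        continue  # skip integer buffers such as num_batches_tracked
    nan_mask = torch.isnan(tensor)
    if nan_mask.any():
        print(name, "All" if nan_mask.all() else "some")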

I would have simply assumed that the weights of the pretrained discriminator were the problem, but these nan values don’t always appear; if I’m lucky, training finishes the 100,000 iterations I set without any nan occurring.
(In the example shown above, it happened at around iteration 20,000 of 100,000.)

In fact, I can reproduce it in the debugger. When I run

(Pdb++) p self._discriminator(I_gt_recon_sample, None)

only about 1 out of every 5 or 6 calls returns an output that is entirely nan; the rest of the time it doesn’t.
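
To quantify that, a loop like the one below counts how often the same forward pass goes nan (a sketch; the 20 repetitions are arbitrary, and the comment is only my guess that the randomness comes from DiffAugment and/or the BatchNorm layers being in training mode):

# sketch: identical input every time, yet only some passes produce nan,
# so the variation presumably comes from random augmentation inside the
# discriminator and/or its BatchNorm layers running in training mode
nan_passes = 0
for _ in range(20):
    with torch.no_grad():
        out = self._discriminator(I_gt_recon_sample, None)
    nan_passes += int(torch.isnan(out).any())
print(f"nan output in {nan_passes}/20 passes")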

I also checked the dtypes of the discriminator parameters, and they print as follows:

(Pdb++) for param in self._discriminator.parameters(): print(param.dtype)

torch.float32
torch.float32
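
One thing I am not sure about: as far as I understand, running_mean / running_var are registered as buffers rather than parameters, so the parameters() loop above does not show their dtype. A sketch that also checks the buffers:

# sketch: buffers (e.g. the BatchNorm running stats) are not covered by parameters()
for name, buf in self._discriminator.named_buffers():
    print(name, buf.dtype)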

I am not sure what on earth is causing this.
I would really appreciate any help. How can I solve this problem of nan values appearing only sometimes?