Network parameters becoming inf after first optimization step

I’ve been trying to train a pix2pix model in PyTorch as one of my first projects with this framework. I’m a new user and was using TF/Keras before this.

I’m using the generator/discriminator network definitions from this repository, along with similar training code.

Some weights of my generator network become inf as soon as I call the optimizer’s step() in the very first training iteration.
With torch.autograd.set_detect_anomaly(True) enabled, I get the error message: RuntimeError: Function 'CudnnConvolutionBackward' returned nan values in its 0th output.

I’m training on images loaded as half-precision floats and scaled to [0, 1] by dividing pixels by 255.
Before training, I ran a test input through the generator and discriminator to check that everything behaved as expected with randomly initialized weights. Both networks produce sensible outputs (no nan or inf) before training; the inf values appear only after the first call to the generator optimizer’s step().

My problem is similar to the one described in this thread, but I didn’t find an answer there.

Here is my training code:

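import warnings

import torch

import models  # external: generator/discriminator definitions from the repository above
# lr, b1, b2, d_output_shape, ds_iterator and prepare_batch are defined elsewhere

# G: U-net generator model, D: PatchGAN discriminator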
G = models.GeneratorUNet()
D = models.Discriminator()

models.init_weights(G) # models: external import, initialize weights with normal(0.0, 0.02)
models.init_weights(D)

adv_loss = torch.nn.MSELoss()
pixelwise_loss = torch.nn.L1Loss()
pixelwise_weight = 100

assert torch.cuda.is_available()

G = G.half().cuda()
D = D.half().cuda()
adv_loss.cuda()
pixelwise_loss.cuda()
device = torch.device("cuda:0")
Tensor = torch.cuda.HalfTensor
torch.autograd.set_detect_anomaly(True)

optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(b1, b2))

def check_state(model):
    for key, val in model.state_dict().items():
        check_numeric(key, val)
        
def check_numeric(name, val):
    if val.isnan().any():
        raise ValueError(f'Found nan values in network element {name}:\n{val}')
    if val.isinf().any():
        raise ValueError(f'Found inf values in network element {name}:\n{val}')

def check_gradients(model):
    for name, param in model.named_parameters():
        if param.grad is not None:
            if not torch.isfinite(param.grad).all():
                raise ValueError(f'{name} has a non-finite gradient')
        else:
            warnings.warn(f'{name}.grad is None')

n_epoch = 5
for epoch_i in range(n_epoch):
    for batch_i, batch in enumerate(ds_iterator):
        inp, tar = prepare_batch(batch)
        check_numeric('batch-input', inp); check_numeric('batch-target', tar)

        # PatchGAN discriminator labels plus some noise (float16 to match the .half() models);
        # plain tensors are enough here, Variable wrappers are no longer needed
        label_tar = torch.ones(d_output_shape, device=device, dtype=torch.float16) + \
            0.05 * torch.rand(d_output_shape, device=device, dtype=torch.float16)
        label_generated = torch.zeros(d_output_shape, device=device, dtype=torch.float16) + \
            0.05 * torch.rand(d_output_shape, device=device, dtype=torch.float16)

        # Generator train step 
        optimizer_G.zero_grad()

        generated = G(inp)
        d_fake_prediction = D(inp, generated)
        check_numeric('discriminator prediction', d_fake_prediction)
        print(d_fake_prediction)
        g_adv_loss = adv_loss(d_fake_prediction, label_tar)
        check_numeric('Generator adversarial loss', g_adv_loss)
        g_pixelwise_loss = pixelwise_loss(generated, tar)
        check_numeric('Generator pixelwise loss', g_pixelwise_loss)
        g_tot_loss = g_adv_loss + pixelwise_weight * g_pixelwise_loss
        check_numeric('Generator loss', g_tot_loss)
        g_tot_loss.backward()

        # ==== Update generator - the problematic step
        check_gradients(G)
        check_state(G)
        optimizer_G.step()
        check_state(G)


        # Discriminator train step 
        optimizer_D.zero_grad()

        d_inp_prediction = D(inp, tar)
        d_real_loss = adv_loss(d_inp_prediction, label_tar)

        d_generated_prediction = D(inp, generated.detach())
        d_generated_loss = adv_loss(d_generated_prediction, label_generated)

        d_tot_loss = d_generated_loss + d_real_loss

        d_tot_loss.backward()
        check_gradients(D)
        check_state(D)
        optimizer_D.step()
        check_state(D)

The check_state(G) function finds inf values in the generator’s down1.model.0 weights, which is the very first Conv2d layer in the U-net, immediately after the update:

--------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-13-08a1061796d8> in <module>()
     45         check_state(G)
     46         optimizer_G.step()
---> 47         check_state(G)
     48 
     49 

<ipython-input-12-53bce0ba077d> in check_state(model)
      1 def check_state(model):
      2     for key, val in model.state_dict().items():
----> 3         check_numeric(key, val)
      4 
      5 

<ipython-input-12-53bce0ba077d> in check_numeric(name, val)
      8         raise ValueError(f'Found nan values in network element {name}:\n{val}')
      9     if val.isinf().any():
---> 10         raise ValueError(f'Found inf values in network element {name}:\n{val}')
     11 
     12 def check_gradients(model):

ValueError: Found inf values in network element down1.model.0.weight:
tensor([[[[-0.0071,  0.0104, -0.0214, -0.0032],
          [ 0.0017, -0.0107,  0.0184, -0.0110],
          [ 0.0123, -0.0030,  0.0021,  0.0219],
          [ 0.0100, -0.0283,  0.0221, -0.0306]]],


        [[[ 0.0104, -0.0177, -0.0127,  0.0025],
          [ 0.0034,  0.0259, -0.0134, -0.0453],
          [-0.0383, -0.0414, -0.0170, -0.0095],
          [-0.0069,  0.0054,  0.0053, -0.0057]]],


        [[[-0.0145,  0.0374,  0.0151,     inf],
          [ 0.0081, -0.0072,  0.0039, -0.0157],
          [-0.0076,  0.0373,  0.0077, -0.0062],
          [ 0.0081,  0.0218,  0.0273, -0.0279]]],


        ...,


        [[[-0.0335, -0.0025, -0.0200, -0.0047],
          [-0.0007,  0.0229,  0.0079,  0.0400],
          [-0.0125,  0.0203, -0.0186, -0.0202],
          [-0.0139,  0.0313, -0.0167,  0.0044]]],


        [[[-0.0120, -0.0248,  0.0106,  0.0044],
          [-0.0176, -0.0036, -0.0154,  0.0140],
          [ 0.0333, -0.0048,  0.0269,  0.0060],
          [ 0.0170, -0.0023,  0.0054,  0.0044]]],


        [[[ 0.0206, -0.0218, -0.0184, -0.0201],
          [ 0.0319, -0.0241,  0.0281,  0.0074],
          [ 0.0211, -0.0055,  0.0014, -0.0230],
          [-0.0184,  0.0142,  0.0044,  0.0003]]]], device='cuda:0',
       dtype=torch.float16)

Here is what’s in down1.model.0:

GeneratorUNet(
  (down1): UNetDown(
    (model): Sequential(
      (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    )
  )
  ( ... )

I logged the norms of the generator’s gradients using the following code:

        g_gradient_norms = sorted([param.grad.data.norm(2).item()
                        for param in G.parameters() 
                        if param.grad is not None])

and found that the largest gradient l2-norm was ~5.0 before the update.
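An L2 norm of ~5.0 for a whole gradient tensor does not rule out underflow, because individual entries can still be tiny. If the suspicion is that Adam's second-moment estimate (kept in float16 for a .half() model) rounds to zero, a per-entry count is more telling. The helper below is a hypothetical diagnostic, assuming beta2=0.999 (PyTorch's Adam default; the post's b2 isn't shown):

```python
import torch

# Under round-to-nearest-even, any magnitude <= 2**-25 (half of float16's
# smallest subnormal, 2**-24) is stored as 0 in float16.
FP16_ROUNDS_TO_ZERO = 2.0 ** -25

def count_underflowing_second_moments(model: torch.nn.Module, beta2: float = 0.999) -> dict:
    """Hypothetical diagnostic: per parameter, count nonzero gradient entries g whose
    first Adam second-moment update (1 - beta2) * g**2 would be stored as 0 in float16."""
    counts = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        g = param.grad.detach().float()   # square in float32 so this check itself doesn't underflow
        v = (1.0 - beta2) * g * g
        counts[name] = int(((v <= FP16_ROUNDS_TO_ZERO) & (g != 0)).sum())
    return counts

# e.g. right after g_tot_loss.backward():
# print({k: n for k, n in count_underflowing_second_moments(G).items() if n})
```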
I’m hoping that someone with more PyTorch experience can spot what I’m missing, because I haven’t been able to figure out what causes the inf values that break my network. Could this be underflow from gradients that are too small? Are there numerical foot-guns I should be aware of?
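One plausible culprit (an assumption, not something confirmed in this thread) is Adam's epsilon. With .half() parameters, Adam's exp_avg and exp_avg_sq state is float16 too. The default eps=1e-8 is less than half of float16's smallest subnormal (2**-24, about 6e-8), so it rounds to zero; for a gradient entry around 1e-3, the first second-moment update (1 - beta2) * g**2 is about 1e-9 and also rounds to zero. The denominator sqrt(v_hat) + eps is then exactly 0 while the first moment is not, which yields an inf weight. The sketch below emulates that first step in plain Python, using struct's 'e' format (IEEE 754 binary16) for rounding. It rounds only the stored quantities, so it simplifies PyTorch's real kernels; the helper names and the defaults lr=2e-4, betas=(0.5, 0.999) are illustrative assumptions, since the post doesn't show lr, b1 or b2.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision (binary16) value."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:  # finite values beyond float16's range (max 65504) overflow to inf
        return math.copysign(math.inf, x)

def div_ieee(a: float, b: float) -> float:
    """Divide with IEEE semantics (x/0 -> +/-inf, 0/0 -> nan) instead of raising."""
    if b == 0.0:
        return math.nan if a == 0.0 else math.copysign(math.inf, a)
    return a / b

def adam_first_step_fp16(grad: float, lr: float = 2e-4, beta1: float = 0.5,
                         beta2: float = 0.999, eps: float = 1e-8) -> float:
    """Rough emulation of the size of Adam's first update when the gradient and
    optimizer state are stored in float16."""
    g = to_fp16(grad)
    m = to_fp16((1 - beta1) * g)          # exp_avg after step 1
    v = to_fp16((1 - beta2) * g * g)      # exp_avg_sq after step 1
    m_hat = m / (1 - beta1)               # bias corrections for step 1
    v_hat = v / (1 - beta2)
    denom = to_fp16(math.sqrt(v_hat) + eps)
    return to_fp16(lr * div_ieee(m_hat, denom))

print(to_fp16(1e-8))                          # 0.0: the default eps vanishes in float16
print(adam_first_step_fp16(1e-3))             # inf
print(adam_first_step_fp16(1e-3, eps=1e-4))   # a finite step of roughly 2e-3
```

In this emulation, raising eps to a value float16 can represent (such as 1e-4) keeps the step finite. The more robust fix is to keep parameters and optimizer state in float32, which is what torch.cuda.amp does.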

Using float16 manually can be tricky, as your parameters and activations might underflow or overflow.
If your model works fine in float32, I would recommend checking out our native mixed-precision training via torch.cuda.amp.
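A minimal sketch of the torch.cuda.amp pattern for one training step follows. The tiny Conv2d model, L1 loss and random tensors are stand-ins (assumptions) for G, the pix2pix losses and real batches. The key points are that the model stays float32 (no .half()), the forward pass runs under autocast, and a GradScaler scales the loss before backward. Without CUDA, autocast and the scaler are disabled so the sketch still runs. Recent PyTorch releases may warn that the torch.cuda.amp spellings are deprecated in favor of torch.amp.

```python
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"   # mixed precision is only enabled on the GPU here

# Stand-in for G: parameters stay float32, no manual .half()
model = nn.Conv2d(1, 8, kernel_size=4, stride=2, padding=1).to(device)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999))
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

def train_step(inp: torch.Tensor, tar: torch.Tensor) -> float:
    optimizer.zero_grad()
    # Inside autocast, eligible ops (e.g. convolutions) run in float16 and
    # precision-sensitive ones (e.g. the L1 loss) run in float32.
    with torch.autocast(device_type=device.type, enabled=use_amp):
        out = model(inp)
        loss = criterion(out, tar)
    # Scale the loss so small float16 gradients don't underflow; scaler.step()
    # unscales the gradients first and skips the update if they contain inf/nan.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()

inp = torch.rand(2, 1, 32, 32, device=device)   # float32 inputs, no manual .half()
tar = torch.rand(2, 8, 16, 16, device=device)   # Conv2d above halves 32x32 to 16x16
print(train_step(inp, tar))
```

In the GAN loop from the question, the generator and discriminator steps would each be wrapped the same way; the AMP examples in the PyTorch documentation describe how to use a GradScaler with several models and optimizers.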

I was using float32 without issues before; float16 is for the GPU I have available. I’ve been using CUDA on WSL. Is amp supported on WSL? (Either way, I’ll try it out.)

Checking back an hour later: amp seems to work fine on WSL and has solved the problem so far. Thank you!

Really glad that it worked.