WGAN-GP with Mixed Precision forces Scaler to 0


I’m trying to implement WGAN-GP. Without mixed precision it works perfectly fine, but with it the critic’s scaled gradients contain NaNs, which causes the scaler to shrink its scale until it vanishes. The scale hovers between 64 and 16 for a while before collapsing. If I move the ‘autograd()’ call for the gradient penalty outside of the ‘with autocast’ block, the scale seems to stabilize, but the training time goes up by 20%.

Here are some observations from the training run.

  • The critic’s scale first drops quickly to 64/32/16, where it stays for a few minutes.
  • The critic’s scale then falls quickly to 2.8E-45, after which it becomes 0.
  • Critic and penalty gradients almost never reach 100 or -100 after scaling.
  • Generator gradients go past 100000 with a scale which stays at 65536 regardless of the GP position.
  • Disabling autograd fixes it.
  • Adam eps and lr to 10E-3 and weight decay to 10E-2. BN eps to 10E-3
  • critic ranges: loss [-10, 0], input[-2, 2], output[-20, 20], params[-2, 2], penalty[-0.2, -0.2]
  • It happened around epoch 130.
  • Moving the GP’s ‘autograd()’ call out of ‘autocast()’ stabilizes the scale, but training takes 20% longer.
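
For reference, the 2.8E-45 figure in the list above is consistent with the scale simply being halved over and over. The sketch below assumes the GradScaler defaults (initial scale 65536, backoff factor 0.5), assumes the scale is stored as a float32 value, and ignores the occasional growth steps. The helper names `to_float32` and `scale_history` are made up for illustration.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def scale_history(init_scale: float = 65536.0, backoff: float = 0.5) -> list:
    """Halve the scale in float32 until it underflows to exactly zero.

    Each halving stands for one step that was skipped because the
    gradients contained inf or NaN values.
    """
    scale = to_float32(init_scale)
    history = [scale]
    while scale > 0.0:
        scale = to_float32(scale * backoff)
        history.append(scale)
    return history


history = scale_history()
# The last few values before the scale reaches zero.
print(history[-3:])
# Number of halvings needed to go from the initial scale to zero.
print(len(history) - 1)
```

Under these assumptions the last positive values are the smallest float32 subnormals (2^-148 is about 2.8E-45), so a scale that small suggests that skipped steps have heavily outnumbered growth steps for a long stretch of training.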

I’ve been pulling my hair out for a week. Could someone explain why this happens, and whether there is an alternative to moving the GP out of ‘autocast()’, since that takes too much time? Could someone also check whether the GP is calculated correctly (should I concat ‘interp_imgs’ and ‘X’)? Should I use a separate scaler for the GP? Why are the scaled gradients of the generator allowed to become so high? Thank you!

def get_grad_penalty(crt, X, y_real, y_fake, scaler):
    batch_size, num_channels, height, width = y_real.shape
    epsilon = torch.rand((batch_size, 1, 1, 1), device=DEVICE).repeat(1, num_channels, height, width)
    interp_imgs = epsilon * y_real + (1-epsilon) * y_fake
    mixed_scores = crt(X, interp_imgs)
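    scaled_grads = torch.autograd.grad(
        # NOTE: the arguments were cut off in the post. The ones below are an
        # assumed reconstruction of the usual scaled WGAN-GP gradient call.
        outputs=scaler.scale(mixed_scores),
        inputs=interp_imgs,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
    )[0]
    return scaled_grads.view(scaled_grads.shape[0], -1)


def train(dataloader, gen_model, gen_optimizer, gen_scaler, crt_model, crt_optimizer, crt_scaler):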

    size = len(dataloader.dataset)
    batch_count = len(dataloader)
    for batch, (X, y) in enumerate(dataloader):
        y_fake = gen_model(X)

        # Train the critic.
        for _ in range(5):
            with torch.autocast(device_type=DEVICE):
                scaled_grads = get_grad_penalty(crt_model, X, y, y_fake, crt_scaler)  # Moving this line up fixes it, but slow.
                real_pred = crt_model(X, y).reshape(-1)
                fake_pred = crt_model(X, y_fake).reshape(-1)
                grads = scaled_grads / crt_scaler.get_scale()
                grad_norm = grads.norm(2, dim=1)
                grad_penalty = torch.mean((grad_norm - 1)**2)
                loss = torch.mean(fake_pred) - torch.mean(real_pred) + 20 * grad_penalty
        # Train the generator.
        with torch.autocast(device_type=DEVICE):
            fake_pred = crt_model(X, y_fake).reshape(-1)
            gen_loss = -torch.mean(fake_pred.float())
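
For comparison, here is a minimal sketch of a gradient penalty that keeps only the critic's forward pass under autocast and does the unscaling and the norm in float32, roughly in the spirit of the gradient penalty pattern in the PyTorch AMP examples. It is not a confirmed fix for the NaNs. The function name `gradient_penalty` and the parameters `scale`, `use_amp` and `eps` are hypothetical, and adding `eps` under the square root is an assumption aimed at the undefined gradient of the norm at zero.

```python
import torch


def gradient_penalty(critic, x, y_real, y_fake, scale=1.0, use_amp=False, eps=1e-12):
    """WGAN-GP penalty with manual loss scaling (hypothetical helper)."""
    # One interpolation weight per sample; broadcasting covers C, H and W.
    alpha = torch.rand(y_real.size(0), 1, 1, 1, device=y_real.device)
    # Detach so the penalty only trains the critic, then track gradients
    # with respect to the interpolated images.
    interp = (alpha * y_real + (1 - alpha) * y_fake).detach().requires_grad_(True)

    # Only the forward pass runs under autocast.
    with torch.autocast(device_type=interp.device.type, enabled=use_amp):
        scores = critic(x, interp)

    # Scale the outputs before differentiating so that low-precision
    # gradients are less likely to underflow. create_graph=True keeps the
    # penalty differentiable with respect to the critic's parameters.
    (scaled_grads,) = torch.autograd.grad(
        outputs=scores.float().sum() * scale,
        inputs=interp,
        create_graph=True,
    )

    # Undo the scaling in float32 before taking the norm.
    grads = scaled_grads.float() / scale
    # eps (an assumption) keeps the square root's gradient finite at zero norm.
    grad_norm = torch.sqrt(grads.flatten(1).pow(2).sum(dim=1) + eps)
    return ((grad_norm - 1) ** 2).mean()
```

In the critic step this might be called as `gradient_penalty(crt_model, X, y, y_fake, scale=crt_scaler.get_scale(), use_amp=True)`, with the result multiplied by the penalty weight and added to the critic loss. Whether this recovers the lost 20% would need to be measured, since the gradient call still runs outside autocast.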