Mixed precision and R1 regularization

Hi

I’m currently trying to train a GAN on images with mixed precision. One important part of this GAN training setup is R1 regularization (https://arxiv.org/pdf/1801.04406.pdf; paper implementation: GAN_stability/train.py at master · LMescheder/GAN_stability · GitHub).
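For context, R1 regularization penalizes the squared gradient norm of the discriminator's real logits with respect to the real images. A minimal full-precision sketch (the names `disc`, `real`, and `r1_gamma` are placeholders, not from my actual model):

```python
import torch

def r1_penalty(disc, real, r1_gamma=10.0):
    # real must be a leaf tensor that requires grad for autograd.grad
    real = real.detach().requires_grad_(True)
    logits = disc(real)
    # Gradient of the summed real logits w.r.t. the real images;
    # create_graph=True so the penalty itself can be backpropagated
    grads = torch.autograd.grad(outputs=logits.sum(), inputs=real,
                                create_graph=True)[0]
    # Squared L2 norm per sample, averaged over the batch
    penalty = grads.square().sum(dim=[1, 2, 3]).mean()
    return (r1_gamma / 2) * penalty

# Tiny usage example with a linear "discriminator" on 4x4 single-channel images
disc = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(16, 1))
real = torch.randn(8, 1, 4, 4)
loss = r1_penalty(disc, real)
```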

When I train the model with AMP, the GradScaler’s scale value gets halved every step and quickly decays toward 0.
The model trains fine without mixed precision, and also with mixed precision when R1 is disabled.

I based my implementation on the gradient penalty example from the AMP documentation, since the two are quite similar (Automatic Mixed Precision examples — PyTorch 1.11.0 documentation).

This is the code block that is causing the issue when AMP is enabled:

real_loss = ...
real_logits = ...  # real must have requires_grad_(True) for autograd.grad below

if r1_reg:
    # Scales the summed logits for autograd.grad's backward pass, producing scaled grads
    scaled_real_logits = scaler.scale(real_logits.sum())
    scaled_r1_grads = torch.autograd.grad(
        outputs=[scaled_real_logits], inputs=[real], create_graph=True)[0]

    # Creates unscaled grads before computing the penalty. scaled_r1_grads are
    # not owned by any optimizer, so ordinary division is used instead of scaler.unscale_():
    inv_scale = 1. / scaler.get_scale()
    r1_grads = scaled_r1_grads * inv_scale

    with torch.cuda.amp.autocast():
        r1_penalty = r1_grads.square().sum([1, 2, 3])
        loss_r1 = torch.mean(r1_penalty * (r1_gamma / 2))
        real_loss += loss_r1

scaler.scale(real_loss).backward()
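For anyone reproducing this: the halving is GradScaler’s documented backoff on non-finite grads (backoff_factor defaults to 0.5, growth_interval to 2000), so a halving on every step means every step produces an inf/NaN somewhere. A pure-Python sketch of that update rule, just to illustrate the mechanism (this is my own mirror of the behavior, not the actual GradScaler code):

```python
import torch

def update_scale(scale, grads, backoff_factor=0.5, growth_factor=2.0,
                 growth_interval=2000, growth_tracker=0):
    # Mirrors GradScaler.update(): back off when any grad is non-finite,
    # grow after growth_interval consecutive finite steps.
    found_inf = any(not torch.isfinite(g).all() for g in grads)
    if found_inf:
        return scale * backoff_factor, 0
    growth_tracker += 1
    if growth_tracker == growth_interval:
        return scale * growth_factor, 0
    return scale, growth_tracker

scale = 2.0 ** 16
# An inf gradient (e.g. from an fp16 overflow) triggers a backoff:
scale, tracker = update_scale(scale, [torch.tensor([float("inf")])])
# A finite gradient leaves the scale unchanged:
scale2, tracker = update_scale(scale, [torch.tensor([1.0])], growth_tracker=tracker)
```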

thanks in advance for any clue!

Best regards,
Andreas