Hi
I’m currently trying to train a GAN on images with mixed precision. An important part of this GAN training is R1 regularization (https://arxiv.org/pdf/1801.04406.pdf; reference implementation: GAN_stability/train.py at master · LMescheder/GAN_stability · GitHub).
When I train the model with AMP, the gradient scaler’s scale value gets halved every step and quickly ends up at 0.
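For context on why this happens: GradScaler halves its scale whenever it finds an inf/NaN in the scaled gradients, and fp16 overflows above 65504, so squaring even a moderately scaled gradient can overflow. A minimal numpy sketch of that mechanism (the scale and gradient values are illustrative, not taken from my run):

```python
import numpy as np

# fp16 overflows above 65504, so squaring a scaled gradient easily produces inf.
# GradScaler halves its scale whenever it sees inf/NaN, which would explain the
# scale collapsing toward 0 step after step.
scale = np.float16(1024.0)           # hypothetical loss scale
grad = np.float16(0.5)               # a perfectly well-behaved gradient
scaled_grad = scale * grad           # 512.0, still representable in fp16
squared = scaled_grad * scaled_grad  # 262144 > 65504 -> inf in fp16
print(squared)                       # inf
```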
The model trains fine without mixed precision, and also with mixed precision when R1 is disabled.
I based my implementation on the gradient penalty example from the AMP documentation, since it is quite similar (Automatic Mixed Precision examples — PyTorch 1.11.0 documentation).
This is the code block that is causing the issue when AMP is enabled:
real_loss = ...
real_logits = ...

if r1_reg:
    # Scales the logits for autograd.grad's backward pass, producing scaled grads
    scaled_real_logits = scaler.scale(real_logits.sum())
    scaled_r1_grads = torch.autograd.grad(outputs=[scaled_real_logits],
                                          inputs=[real],
                                          create_graph=True)[0]

    # Creates unscaled grads before computing the penalty. scaled_r1_grads are
    # not owned by any optimizer, so ordinary division is used instead of
    # scaler.unscale_():
    inv_scale = 1. / scaler.get_scale()
    r1_grads = scaled_r1_grads * inv_scale

    with torch.cuda.amp.autocast():
        r1_penalty = r1_grads.square().sum([1, 2, 3])
        loss_r1 = torch.mean(r1_penalty * (r1_gamma / 2))
        real_loss += loss_r1

scaler.scale(real_loss).backward()
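For comparison, here is the same R1 computation in full precision on CPU, with no scaler involved, which is the variant that runs fine for me. The toy discriminator `D` and the tensor shapes are illustrative stand-ins, not my actual model:

```python
import torch

# Full-precision sketch of the R1 penalty (Mescheder et al. 2018),
# isolating the R1 math from any AMP/GradScaler logic.
D = torch.nn.Sequential(torch.nn.Linear(3 * 4 * 4, 1))  # toy discriminator

real = torch.randn(8, 3, 4, 4, requires_grad=True)  # stand-in real batch
real_logits = D(real.flatten(1))

# Gradient of the summed logits w.r.t. the input images, kept in the
# graph (create_graph=True) so the penalty itself is differentiable.
r1_grads = torch.autograd.grad(outputs=real_logits.sum(),
                               inputs=real,
                               create_graph=True)[0]

r1_gamma = 10.0
r1_penalty = r1_grads.square().sum(dim=[1, 2, 3])  # per-sample squared grad norm
loss_r1 = (r1_gamma / 2) * r1_penalty.mean()
loss_r1.backward()  # double backward; stays finite without scaling
print(loss_r1.item())
```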
Thanks in advance for any clues!
Best regards,
Andreas