For some reason, when I train WGAN-GP with mixed precision using the torch.cuda.amp package, something happens to my GradScaler for the critic. During training, the scaler's scale keeps decreasing from its usual values down to very low ones (around 1e-7).
Strangely, it happens only for the critic and only for the WGAN-GP model. When I try LS-GAN or a plain GAN, everything is fine.
So the problem seems to be somewhere in the gradient-penalty (GP) part, but I just can't pin it down: the fp16 gradients computed for the GP contain NaNs, and I have no idea why.
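In case it helps, here is a minimal sketch of the kind of critic step I mean (not my exact code; the latent size, the GP weight of 10, and the image-shaped inputs are just placeholders):

```python
import torch
from torch.cuda.amp import autocast, GradScaler

scaler_D = GradScaler()

def gradient_penalty(critic, real, fake, device):
    # Random interpolation between real and fake samples
    eps = torch.rand(real.size(0), 1, 1, 1, device=device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    # The critic forward pass on the interpolates runs under autocast,
    # so the backward pass for the penalty goes through fp16 ops
    with autocast():
        score = critic(interp)
    grad = torch.autograd.grad(
        outputs=score, inputs=interp,
        grad_outputs=torch.ones_like(score),
        create_graph=True, retain_graph=True,
    )[0]
    grad = grad.view(grad.size(0), -1)
    return ((grad.norm(2, dim=1) - 1) ** 2).mean()

def critic_step(critic, generator, real, opt_D, device):
    opt_D.zero_grad(set_to_none=True)
    with autocast():
        # 128 is a placeholder latent size
        fake = generator(torch.randn(real.size(0), 128, device=device)).detach()
        loss_D = critic(fake).mean() - critic(real).mean()
    gp = gradient_penalty(critic, real, fake, device)
    loss = loss_D + 10.0 * gp
    scaler_D.scale(loss).backward()
    scaler_D.step(opt_D)
    scaler_D.update()
```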
So the question is: what causes such low scale values, and what could be a fix for it?