Hello,
I am experiencing issues when applying 16-bit precision (precision=16) in PyTorch Lightning.
I am optimizing the generator and discriminator via `net_G_A` and `net_D_A`, and optimizing the `PatchNCELoss` via `net_F_A`.
I have confirmed in the documentation that manual backward is required when using multiple optimizers, and the code runs without issues at precision 32.
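For context, the only thing I change between the two runs is the precision argument on the Trainer (a minimal sketch; the accelerator/devices values are placeholders and everything else stays identical):

```python
from pytorch_lightning import Trainer

# precision 32: trains without issues
trainer = Trainer(accelerator="gpu", devices=1, precision=32)

# precision 16 (native AMP): the problematic run described below
trainer = Trainer(accelerator="gpu", devices=1, precision=16)
```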
However, when I set precision=16 in the trainer to enable 16-bit precision, the loss values become NaN. The NaN always occurs in the GAN loss term. To address this, I have tried the following:
- Casting the data and the resulting loss values to `float()` when computing the GAN loss.
- Manually implementing autocasting and gradient scaling in `training_step` (rough sketch below).
- Clipping gradients.
Despite trying all these methods, I couldn’t avoid the NaN values in the loss.
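For the second item, the manual autocast/scaling attempt looked roughly like this (a simplified sketch of just the generator branch, with the Trainer kept at precision=32 so Lightning's own scaler does not also kick in; the scaler is created once in __init__, and the exact placement of the unscale/clip calls may have differed in what I actually ran):

```python
import torch
from torch.cuda.amp import autocast, GradScaler

# inside ProposedSynthesisModule; self.scaler = GradScaler() is created in __init__
def training_step(self, batch, batch_idx):
    optimizer_G_A, optimizer_D_A, optimizer_F_A = self.optimizers()
    real_a, real_b, fake_b = self.model_step(batch)

    with optimizer_G_A.toggle_model():
        with autocast():  # run forward + loss computation in fp16
            loss_G, loss_gan, loss_style, loss_nce = self.backward_G(
                real_a, real_b, fake_b, self.params.lambda_style, self.params.lambda_nce
            )
        # scale before backward, unscale before clipping, then step through the scaler
        self.scaler.scale(loss_G).backward()
        self.scaler.unscale_(optimizer_G_A.optimizer)  # .optimizer: the wrapped torch optimizer
        torch.nn.utils.clip_grad_norm_(self.netG_A.parameters(), 0.5)
        self.scaler.step(optimizer_G_A.optimizer)
        self.scaler.update()
        optimizer_G_A.zero_grad()
```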
Is there anything I might be missing in this situation?
Thank you very much for reading the long question!
```python
class ProposedSynthesisModule(BaseModule_AtoB):
    def __init__(
        self,
        netG_A: torch.nn.Module,  # Generator
        netD_A: torch.nn.Module,  # Discriminator
        netF_A: torch.nn.Module,  # For PatchNCELoss
        optimizer,
        params,
        scheduler,
        *args,
        **kwargs: Any
    ):
        super().__init__(params, *args, **kwargs)

        # assign generator
        self.netG_A = netG_A
        self.netD_A = netD_A
        self.netF_A = netF_A

        self.save_hyperparameters(logger=False)
        self.automatic_optimization = False  # perform manual

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.optimizer = optimizer
        self.params = params
        self.scheduler = scheduler

        # assign contextual loss
        style_feat_layers = {
            "conv_2_2": 1.0,
            "conv_3_2": 1.0,
            "conv_4_2": 1.0,
            "conv_4_4": 1.0
        }

        # loss function
        self.criterionContextual = Contextual_Loss(style_feat_layers)
        self.criterionGAN = GANLoss(gan_type='lsgan')
        self.criterionNCE = PatchNCELoss(False, nce_T=0.07, batch_size=params.batch_size)

        # PatchNCE specific initializations
        self.nce_layers = [0, 2, 4, 6]  # range: 0~6
        self.flip_equivariance = params.flip_equivariance
        # self.flipped_for_equivariance = False

    def backward_G(self, real_a, real_b, fake_b, lambda_style, lambda_nce):
        pred_fake = self.netD_A(fake_b.detach())
        loss_gan = self.criterionGAN(pred_fake, True)
        assert not torch.isnan(loss_gan).any(), "GAN Loss is NaN"

        ## Contextual loss
        loss_style_B = self.criterionContextual(real_b, fake_b)
        loss_style = loss_style_B * lambda_style
        assert not torch.isnan(loss_style).any(), "Contextual Loss is NaN"

        # PatchNCE loss (real_a, fake_b)
        n_layers = len(self.nce_layers)
        feat_b = self.netG_A(fake_b, real_a, self.nce_layers, encode_only=True)

        flipped_for_equivariance = np.random.random() < 0.5
        if self.flip_equivariance and flipped_for_equivariance:
            feat_b = [torch.flip(fb, [3]) for fb in feat_b]

        feat_a = self.netG_A(real_a, real_b, self.nce_layers, encode_only=True)
        feat_a_pool, sample_ids = self.netF_A(feat_a, 256, None)
        feat_b_pool, _ = self.netF_A(feat_b, 256, sample_ids)

        total_nce_loss = 0.0
        for f_a, f_b in zip(feat_b_pool, feat_a_pool):
            loss = self.criterionNCE(f_a, f_b) * lambda_nce
            total_nce_loss += loss.mean()
        loss_nce = total_nce_loss / n_layers
        assert not torch.isnan(loss_nce).any(), "NCE Loss is NaN"

        loss_G = loss_gan + loss_style + loss_nce
        assert not torch.isnan(loss_G).any(), "Total Loss is NaN"
        return loss_G, loss_gan, loss_style, loss_nce

    def training_step(self, batch: Any, batch_idx: int):
        optimizer_G_A, optimizer_D_A, optimizer_F_A = self.optimizers()
        real_a, real_b, fake_b = self.model_step(batch)

        with optimizer_G_A.toggle_model():
            # with optimizer_F_A.toggle_model():
            loss_G, loss_gan, loss_style, loss_nce = self.backward_G(real_a, real_b, fake_b, self.params.lambda_style, self.params.lambda_nce)
            self.manual_backward(loss_G)
            self.clip_gradients(
                optimizer_G_A, gradient_clip_val=0.5, gradient_clip_algorithm="norm"
            )
            self.clip_gradients(
                optimizer_F_A, gradient_clip_val=0.5, gradient_clip_algorithm="norm"
            )
            optimizer_G_A.step()
            optimizer_F_A.step()
            optimizer_G_A.zero_grad()
            optimizer_F_A.zero_grad()

        # self.loss_G = loss_G.detach() * 0.1 + self.loss_G * 0.9
        self.log("G_loss", loss_G.detach(), prog_bar=True)
        self.log("loss_gan", loss_gan.detach(), prog_bar=True)
        self.log("loss_style", loss_style.detach(), prog_bar=True)
        self.log("loss_nce", loss_nce.detach(), prog_bar=True)

        with optimizer_D_A.toggle_model():
            loss_D_A = self.backward_D_A(real_b, fake_b)
            self.manual_backward(loss_D_A)
            self.clip_gradients(
                optimizer_D_A, gradient_clip_val=0.5, gradient_clip_algorithm="norm"
            )
            optimizer_D_A.step()
            optimizer_D_A.zero_grad()

        self.log("Disc_A_Loss", loss_D_A.detach(), prog_bar=True)

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
        """
        optimizers = []
        schedulers = []

        optimizer_G_A = self.hparams.optimizer(params=self.netG_A.parameters())
        optimizers.append(optimizer_G_A)
        optimizer_D_A = self.hparams.optimizer(params=self.netD_A.parameters())
        optimizers.append(optimizer_D_A)
        optimizer_F_A = self.hparams.optimizer(params=self.netF_A.parameters())
        optimizers.append(optimizer_F_A)

        if self.hparams.scheduler is not None:
            scheduler_G_A = self.hparams.scheduler(optimizer=optimizer_G_A)
            schedulers.append(scheduler_G_A)
            scheduler_D_B = self.hparams.scheduler(optimizer=optimizer_D_A)
            schedulers.append(scheduler_D_B)
            scheduler_F_A = self.hparams.scheduler(optimizer=optimizer_F_A)
            schedulers.append(scheduler_F_A)

        return optimizers, schedulers


class GANLoss(nn.Module):
    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0, reduction='mean'):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss(reduction=reduction)
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss(reduction=reduction)
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'wgan_softplus':
            self.loss = self._wgan_softplus_loss
        elif self.gan_type == 'hinge':
            self.loss = nn.ReLU()
        elif self.gan_type == 'swd':
            self.loss = self._slicedWassersteinDistance_loss
        # elif self.gan_type == 'bce':
        #     self.loss = nn.BCELoss()
        else:
            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')

    def forward(self, input, target_is_real, is_disc=False):
        target_label = self.get_target_label(input, target_is_real)

        if self.gan_type == 'swd' and not is_disc:
            swd_loss = self.loss(input, target_label)
            return swd_loss * self.loss_weight

        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                loss = -input.mean()
        else:  # other gan types
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight
```