Hi, I'm getting NaN values after the first backward pass. The value of y_kld is on the order of 1e-8; could that be the cause? In the code below, q_y is an intermediate output of my network, and self.class_diff is a NumPy constant. How can I fix this?
x_true1 = x_true1.to(self.device)
x_recon, mu, logvar, z, cat_logit = self.VAE(x_true1)
vae_recon_loss = recon_loss(x_true1, x_recon)
q_y = F.softmax(cat_logit, dim=1)
log_q_y = torch.log(q_y)
vae_kld = kl_divergence(mu, logvar)
y_kld = torch.sum(torch.mul(q_y, (log_q_y - self.class_diff)), 1).mean()
D_z = self.D(z)
vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()
Vae_loss = vae_recon_loss + vae_kld + self.gamma * vae_tc_loss + y_kld
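A small loss value such as 1e-8 does not by itself produce NaN. A more likely suspect, offered here as a hypothesis rather than a confirmed diagnosis, is torch.log(F.softmax(...)): when one logit dominates, softmax can underflow to exactly 0 in float32, torch.log turns that into -inf, and 0 * -inf is NaN. The sketch below computes the same y_kld term with F.log_softmax, which stays finite in that case, and converts the NumPy constant to a tensor. The function name categorical_kl_term and the example logits are hypothetical, chosen only to reproduce the underflow.

```python
import numpy as np
import torch
import torch.nn.functional as F


def categorical_kl_term(cat_logit, class_diff):
    """Batch mean of sum_k q_k * (log q_k - class_diff_k), with q = softmax(cat_logit)."""
    # log_softmax uses the log-sum-exp trick, so it stays finite even where
    # softmax underflows to exactly 0 (torch.log(0) is -inf, and 0 * -inf is NaN).
    log_q_y = F.log_softmax(cat_logit, dim=1)
    q_y = log_q_y.exp()
    # Turn the NumPy constant into a tensor with the logits' dtype and device.
    class_diff = torch.as_tensor(class_diff, dtype=cat_logit.dtype, device=cat_logit.device)
    return (q_y * (log_q_y - class_diff)).sum(dim=1).mean()


# Hypothetical inputs: the first row has one very confident logit.
cat_logit = torch.tensor([[0.0, 200.0], [1.0, 2.0]], requires_grad=True)
class_diff = np.log(np.array([0.5, 0.5]))  # stand-in for self.class_diff

# Original formulation: softmax followed by torch.log.
q_naive = F.softmax(cat_logit, dim=1)
naive = (q_naive * (torch.log(q_naive) - torch.as_tensor(class_diff, dtype=torch.float32))).sum(1).mean()

# Stable formulation, including a backward pass to check the gradients.
stable = categorical_kl_term(cat_logit, class_diff)
stable.backward()
print(naive.item(), stable.item())  # naive is nan, stable is finite
```

If the NaN persists after this change, the other terms (for example kl_divergence, which is not shown) would be the next place to look; running one step under torch.autograd.detect_anomaly() reports which operation first produced the NaN.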