Gradient clipping not working

I’m building a model similar to a GAN. As the discriminator gets better, its final sigmoid layer outputs exactly 1, which makes the loss term log(1 - D(Z)) become -inf. This is the code for my discriminator:

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, hidden_size, latent_dim):
        super(Discriminator, self).__init__()

        self.dim_h = hidden_size
        self.n_z = latent_dim

        self.main = nn.Sequential(
            nn.Linear(self.n_z, self.dim_h),
            nn.Linear(self.dim_h, self.dim_h),
            nn.Linear(self.dim_h, 1),
            nn.Sigmoid(),  # maps the score to a probability D(z) in (0, 1)
        )

    def forward(self, x):
        x = self.main(x)
        return x
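The saturation is a floating-point effect, not just "very close to 1". A quick illustration, assuming the default float32 dtype (the logit value 20 is arbitrary): the true value of sigmoid(20) is about 1 - 2e-9, which is closer to 1.0 than to the next float32 below it, so it rounds to exactly 1.0 and 1 - D(Z) becomes exactly 0.

```python
import torch

# A moderately large logit, as a confident discriminator might produce (float32 by default).
logit = torch.tensor([20.0])
d = torch.sigmoid(logit)

print(d)                 # rounds to exactly 1.0 in float32
print(torch.log(1 - d))  # log(0) gives -inf
```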

And this is the discriminator training step:

X_train, y_train, length = next(data_generator(data, y, word_to_dict, batch_size=batch_size))
X_train = X_train.cuda()
encoded_z, _, _, _ = model.encode(X_train)
z = torch.randn_like(encoded_z)  # prior sample, already on the same device as encoded_z
d_z = critic(z)
d_z_hat = critic(encoded_z)
d_z_loss = lambd * torch.log(d_z).mean()
d_z_hat_loss = lambd * torch.log(1 - d_z_hat).mean()
loss = -(d_z_hat_loss + d_z_loss)

d_optim.zero_grad()
loss.backward()
# clip after backward() has populated the gradients and before the parameter update
torch.nn.utils.clip_grad_norm_(critic.parameters(), 0.25)
d_optim.step()

To keep the -inf loss from destabilizing backpropagation, I clip the gradient norm before d_optim.step().

I have also tried using torch.nn.utils.clip_grad_value_(parameters, clip_value) to solve the -inf problem.

Neither function worked. Is there anything I am missing?

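Gradient norm clipping won’t work here. clip_grad_norm_ rescales all gradients by a single factor, max_norm / total_norm, and a +/-Inf gradient makes total_norm itself Inf. The factor then becomes 0, so instead of being brought back into range, the finite gradients are zeroed and the Inf entries turn into NaN (Inf * 0).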
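A minimal sketch of this behavior, using a hand-set gradient on a single parameter (the values are illustrative and stand in for gradients produced by a -inf loss). The error_if_nonfinite argument is available in recent PyTorch versions and raises instead of silently producing NaNs:

```python
import torch
import torch.nn as nn

# One parameter whose gradient contains an Inf entry (illustrative values only).
p = nn.Parameter(torch.zeros(2))
p.grad = torch.tensor([float("inf"), 1.0])

total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=0.25)
clipped = p.grad.clone()

# total_norm is Inf, so the scale factor max_norm / total_norm is 0:
# the Inf entry becomes NaN (inf * 0) and the finite entry becomes 0.
print(total_norm)
print(clipped)

# Ask clip_grad_norm_ to raise on a non-finite total norm instead.
p.grad = torch.tensor([float("inf"), 1.0])
raised = False
try:
    torch.nn.utils.clip_grad_norm_([p], max_norm=0.25, error_if_nonfinite=True)
except RuntimeError as err:
    raised = True
    print("clip_grad_norm_ refused:", err)
```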

Clipping by value would clamp +/-Inf gradients to +/-clip_value, but unfortunately a +/-Inf loss can also create NaN gradients, as seen here:

import torch
import torch.nn as nn

model = nn.Linear(1, 1)
x = torch.randn(1, 1)
out = model(x)

loss = torch.log(out - out)
print(loss)
> tensor([[-inf]], grad_fn=<LogBackward>)

loss.backward()
print(model.weight.grad)
> tensor([[nan]])
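Value clipping cannot repair such NaNs, because clamping a NaN returns NaN. A small sketch with a hand-set gradient (illustrative values, not taken from the model above):

```python
import torch
import torch.nn as nn

p = nn.Parameter(torch.zeros(3))
# Illustrative gradient mixing +Inf, -Inf and NaN entries.
p.grad = torch.tensor([float("inf"), float("-inf"), float("nan")])

torch.nn.utils.clip_grad_value_([p], clip_value=0.25)

# +/-Inf are clamped to +/-0.25, but the NaN entry stays NaN.
print(p.grad)
```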

I think the right approach is to avoid creating invalid loss values in the first place, e.g. by adding a small eps value inside the log, so that its argument can never be exactly 0.
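A sketch of the eps approach on a saturated output, plus one alternative that is not from the post: computing log(1 - D(Z)) from the raw logit with F.logsigmoid, using the identity log(1 - sigmoid(x)) = logsigmoid(-x). The alternative assumes the discriminator returns logits (i.e. the final nn.Sigmoid is removed); the eps value is arbitrary. The same treatment applies to the torch.log(d_z) term.

```python
import torch
import torch.nn.functional as F

eps = 1e-7  # arbitrary small constant, tune for your setup

# Option 1: eps inside the log. The sigmoid output has saturated to exactly 1.0.
logit = torch.tensor([20.0], requires_grad=True)
d_z_hat = torch.sigmoid(logit)
loss_eps = -torch.log(1 - d_z_hat + eps).mean()
loss_eps.backward()
# The loss is finite; the gradient is finite too, but 0 because the sigmoid is saturated.
print(loss_eps, logit.grad)

# Option 2 (assumes the discriminator outputs logits):
# -log(1 - sigmoid(x)) computed stably as -logsigmoid(-x).
logit2 = torch.tensor([20.0], requires_grad=True)
loss_logits = -F.logsigmoid(-logit2).mean()
loss_logits.backward()
# The loss is about 20 and the gradient is about 1, so it still carries a signal.
print(loss_logits, logit2.grad)
```

With eps the loss stays finite, but a fully saturated sigmoid still passes no gradient; working on logits (the same idea behind nn.BCEWithLogitsLoss) keeps both the loss and its gradient usable.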
