NaN gradient despite masking

How can I compute this function in a way that handles gradients correctly?

import torch

def f(x):
    return torch.where(x > 0, x, x / (1 - x))

At x == 1 this gives a NaN gradient, even though torch.where selects the x branch there and the expected gradient is 1:

x = torch.tensor(1., requires_grad=True)
y = f(x)
print(y)
y.backward()
print(x.grad)  # tensor(nan), although the selected branch is just x
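To see where the NaN comes from, it helps to keep a handle on the branch that torch.where discards. In this minimal sketch (the name `branch` is only for illustration), the discarded branch is still evaluated (1 / 0 = inf), torch.where sends it a zero gradient, and the division's backward then computes 0 / 0 = NaN, which is added into x.grad.

```python
import torch

x = torch.tensor(1., requires_grad=True)
branch = x / (1 - x)   # evaluated for every element, even where it will be discarded
branch.retain_grad()   # keep the gradient of this intermediate tensor for inspection
y = torch.where(x > 0, x, branch)
y.backward()

print(branch)       # inf: 1 / (1 - 1)
print(branch.grad)  # 0: torch.where routes a zero gradient to the unselected branch
print(x.grad)       # nan: the division's backward computes 0 / 0 and adds it to the gradient of x
```

So masking the forward result is not enough: the inputs to the unselected branch must also be safe.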

I tried using masked_scatter but it also doesn’t work:

def f(x):
    return x.masked_scatter(x < 0, x / (1 - x))
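A minimal sketch of one way to make the masked_scatter approach work is to compute the rational branch only on the elements that need it. masked_scatter copies source elements, in order, into the positions where the mask is True, so the source should hold just those values rather than a full-size tensor. The mask x <= 0 mirrors the else-branch of the torch.where version; for those elements 1 - x is at least 1, so no inf is ever produced.

```python
import torch

def f(x):
    # Elements that take the x / (1 - x) branch; here 1 - x >= 1, so no division by zero.
    mask = x <= 0
    xm = x[mask]  # 1-D tensor of just the masked elements, in order
    # masked_scatter fills the True positions of mask with source elements taken in order,
    # so the source contains only the values for those positions.
    return x.masked_scatter(mask, xm / (1 - xm))

x = torch.tensor([-2., 0.5, 1., 3.], requires_grad=True)
y = f(x)
y.sum().backward()
print(y)       # values approx. [-0.6667, 0.5, 1.0, 3.0]
print(x.grad)  # approx. [0.1111, 1.0, 1.0, 1.0]: finite at x == 1
```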

If your problem is related to the presence of NaNs, I think you could:

  1. use an if statement to avoid x == 1, for instance by replacing x with a slightly smaller value such as 0.98;
  2. before returning the value, replace any NaNs: store the result in a variable, e.g. out = torch.where(x > 0, x, x / (1 - x)), then return torch.nan_to_num(out, nan=value).

I don’t know if these solutions are the best, but I think they’re worth trying.
Let me know if you’re able to solve the problem :smiley:
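The first suggestion (avoiding x == 1) can also be applied element-wise, without a Python if, by substituting a safe input only in the branch that torch.where discards; this is often called the "double where" trick. In this minimal sketch the substitute value 0.0 is an arbitrary choice: any value other than 1 (such as 0.98) keeps 1 - x_safe nonzero, and because those entries are discarded, neither the output nor the gradient of the selected branch changes.

```python
import torch

def f(x):
    pos = x > 0
    # Where the x branch is selected, give the rational branch a harmless input (0.0),
    # so 1 - x_safe is never zero and its backward never produces 0 / 0.
    x_safe = torch.where(pos, torch.zeros_like(x), x)
    return torch.where(pos, x, x_safe / (1 - x_safe))

# The example from the question: the gradient at x == 1 is now 1 instead of NaN.
x = torch.tensor(1., requires_grad=True)
f(x).backward()
print(x.grad)  # tensor(1.)

# Element-wise on a vector: gradient is 1 for x > 0 and 1 / (1 - x)**2 otherwise.
v = torch.tensor([-2., 0., 0.5, 1., 3.], requires_grad=True)
f(v).sum().backward()
print(v.grad)  # approx. [0.1111, 1.0, 1.0, 1.0, 1.0]
```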


The solution is to use a Masking() layer, available in Keras, with mask_value=0. Without it, the empty (all-zero) padding vectors are included in the loss; with Masking(), as described in the Keras documentation, the padding vectors are skipped and not included.
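The same idea can be sketched in PyTorch. This is an analogue, not the Keras Masking layer itself: a timestep counts as padding when all of its input features equal the mask value (the rule the Keras Masking layer documents), and those timesteps are left out of the loss. The helper name masked_mse and the (batch, time, features) layout are hypothetical, chosen for illustration.

```python
import torch

def masked_mse(pred, target, inputs, mask_value=0.0):
    """Hypothetical helper: MSE over the timesteps that are not padding.

    inputs: (batch, time, features); pred and target: (batch, time, out_features).
    """
    # keep[b, t] is False only when every feature of inputs[b, t] equals mask_value
    keep = (inputs != mask_value).any(dim=-1)
    # squared error averaged over output features -> shape (batch, time)
    err = ((pred - target) ** 2).mean(dim=-1)
    # average only over the kept timesteps; padded ones never enter the loss
    return err[keep].mean()

inputs = torch.tensor([[[1., 2.], [3., 4.], [0., 0.]]])  # last timestep is padding
pred = torch.tensor([[[1.], [2.], [100.]]])
target = torch.tensor([[[0.], [2.], [0.]]])
print(masked_mse(pred, target, inputs))  # tensor(0.5000): the padded error of 10000 is ignored
```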
