How can I compute this function in a way that handles gradients correctly?

```python
def f(x):
    return torch.where(x > 0, x, x / (1 - x))
```

This produces an incorrect `nan` gradient at `x == 1`:

```python
import torch

x = torch.tensor(1., requires_grad=True)
y = f(x)
print(y)
y.backward()
print(x.grad)  # the gradient that comes out as nan
```
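As far as I can tell, the `nan` comes from the backward pass through `x / (1 - x)`, even when that branch is not the one selected. The sketch below is only my attempt to isolate it: it backpropagates a zero upstream gradient (which is what an unselected branch would receive) through the division alone. The names `branch` and `g` are just for illustration.

```python
import torch

x = torch.tensor(1., requires_grad=True)

# The branch on its own: 1 / (1 - 1) evaluates to inf at x == 1.
branch = x / (1 - x)

# Backpropagate an upstream gradient of 0 through the division.
# The local derivative is infinite here, and 0 * inf is nan.
(g,) = torch.autograd.grad(branch, x, grad_outputs=torch.tensor(0.))
print(g)
```

So the forward value can be perfectly finite while the gradient is still poisoned.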

I tried using `masked_scatter` but it also doesn’t work:

```python
def f(x):
    return x.masked_scatter(x < 0, x / (1 - x))
```

If the problem is the NaNs, I think you could:

1. use an `if` statement to avoid `x == 1`, for instance by using a slightly smaller value such as `x = 0.98`;
2. check for NaNs before returning the value: create a variable `function = torch.where(x > 0, x, x / (1 - x))`, then return `torch.nan_to_num(function, nan=value)`, where `value` is the replacement you want.

I don’t know if these solutions are the best, but I guess they’re worth trying.
Let me know if you're able to solve the problem.
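
One more thing that might be worth trying, as a minimal sketch only. The NaN seems to appear in the backward pass through the unselected `x / (1 - x)` branch rather than in the forward value, so a common workaround is to give that branch a harmless input wherever it is not selected. This assumes `f` is the `torch.where` expression above; the name `f_safe` and the dummy value `0` are my own choices for illustration.

```python
import torch

def f_safe(x):
    # Where x > 0 the second branch is never used, so replace its input
    # with a dummy value (0) that keeps x / (1 - x) and its derivative finite.
    safe_x = torch.where(x > 0, torch.zeros_like(x), x)
    return torch.where(x > 0, x, safe_x / (1 - safe_x))

x = torch.tensor(1., requires_grad=True)
y = f_safe(x)
y.backward()
print(y)       # tensor(1., grad_fn=...)
print(x.grad)  # tensor(1.)
```

For `x <= 0` the dummy value is never substituted, so the result and gradient there should match the original expression.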

