# Incorrect gradient calculation with torch.where and NaNs

I am implementing a truncated Gaussian distribution in PyTorch (https://github.com/TheAeryan/stable-truncated-gaussian). In order to do that in a numerically stable manner, I need conditional code (if-then-else) that computes the mean, variance and log-probs one way or another, depending on the specific input values. Here is an example:

```python
def _F_1(x, y):
    if torch.abs(x) > torch.abs(y):
        out = SequentialTruncatedGaussian._F_1(y, x)
    elif torch.abs(x - y) < 1e-7:
        out = SequentialTruncatedGaussian._P_1(x, y - x)
    elif x <= 0 and y <= 0:
        delt = SequentialTruncatedGaussian._delta(x, y)
        out = (1 - delt) / (delt * erfcx(-y) - erfcx(-x))
    elif x >= 0 and y >= 0:
        delt = SequentialTruncatedGaussian._delta(x, y)
        out = (1 - delt) / (erfcx(x) - delt * erfcx(y))
    else:
        delt = SequentialTruncatedGaussian._delta(x, y)
        out = ((1 - delt) * torch.exp(-x**2)) / (erf(y) - erf(x))

    return out
```

I am now adapting my code to run in parallel, so I need a way to express these if-then-else conditions in PyTorch. What I do is compute the output value associated with each condition and then use torch.where to select the correct value according to its associated condition. This is how I have implemented the _F_1 function above:

```python
def _F_1(x_, y_):
    # All values in tensor @x must be smaller (in magnitude) than values in tensor @y
    x = torch.where(torch.abs(x_) <= torch.abs(y_), x_, y_)
    y = torch.where(torch.abs(y_) >= torch.abs(x_), y_, x_)

    # Values
    delt = NewParallelTruncatedGaussian._delta(x, y)
    one_minus_delt = 1 - delt
    out1 = NewParallelTruncatedGaussian._P_1(x, y - x)
    out2 = one_minus_delt / (delt * erfcx(-y) - erfcx(-x) + EPS)
    out3 = one_minus_delt / (erfcx(x) - delt * erfcx(y) + EPS)
    out4 = (one_minus_delt * torch.exp(-x**2)) / (erf(y) - erf(x) + EPS)

    # Conditions (detached so that gradients do not flow through them)
    x_cond, y_cond = x.detach(), y.detach()

    out1_cond = torch.abs(x_cond - y_cond) < 1e-7  # e.g. tensor([False, False, False, False])
    out2_cond = torch.logical_and(torch.logical_and(x_cond <= 0, y_cond <= 0),
                                  torch.logical_not(out1_cond))  # e.g. tensor([False,  True, False, False])
    out3_cond = torch.logical_and(torch.logical_and(x_cond >= 0, y_cond >= 0),
                                  torch.logical_not(out1_cond))  # e.g. tensor([False, False, False,  True])
    out4_cond = torch.logical_and(torch.logical_and(torch.logical_not(out1_cond),
                                                    torch.logical_not(out2_cond)),
                                  torch.logical_not(out3_cond))  # all other conditions must be false
                                                                 # e.g. tensor([ True, False,  True, False])

    # Final expression: select each branch's value where its condition holds
    final_out = (torch.where(out1_cond, out1, 0.0) + torch.where(out2_cond, out2, 0.0)
                 + torch.where(out3_cond, out3, 0.0) + torch.where(out4_cond, out4, 0.0))

    return final_out
```

This works perfectly for the forward pass but, in the backward pass (i.e., when computing the gradients), I get NaN gradients. What happens is that some of these expressions produce NaNs or infs at some positions. Those positions are discarded by torch.where (i.e., substituted by 0), yet I still get NaN gradients for the final result (final_out), even though the values which produce the NaN gradients are never used in computing final_out. In other words, I think that if I were able to substitute the NaN gradients with any other value (e.g., 0), the gradient calculation would work perfectly.
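For reference, this failure mode can be reproduced with a minimal toy example (my own sketch, not the code above): even though the forward result of the discarded branch is replaced by torch.where, its gradient still poisons the backward pass.

```python
import torch

# Minimal repro (toy example, not the original code): the "branch not
# taken" of torch.where produces an inf in the forward pass, and its
# nan gradient is NOT discarded in the backward pass.
x = torch.tensor([0.0, 1.0], requires_grad=True)
bad = 1.0 / x                                         # inf at x == 0
out = torch.where(x > 0, bad, torch.zeros_like(bad))  # forward is fine: [0., 1.]
out.sum().backward()
print(x.grad)                                         # tensor([nan, -1.])
```

The nan appears because the chain rule multiplies the mask (0) by the inf gradient of `1.0 / x` at `x == 0`, and `0 * inf` is nan.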

Any suggestions on how to solve this issue?

Could you post an example input which would yield valid values in the first example but NaNs in the latter?

Hi Carlos!

First, this is a known issue (with no simple fix for `torch.where()`). See,
for example, GitHub issues 68425 and 70342.

Second, I believe that the best fix is to avoid producing `nan`s (or the values
that subsequently produce `nan`s) in the forward pass, even in the "branch
not taken" of the `torch.where()`. Instead, compute in their places "safe"
(non-`nan`-producing) values that can be incorrect, because `torch.where()`
will successfully discard the incorrect (but non-`nan`) gradients that they
produce.

For a concrete example of how to feed "safe" values to the "branch not taken"
in a problem analogous to yours, see this post:
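As a minimal sketch of the idea (my own toy example, not the linked post): substitute a harmless dummy value into the unsafe branch's input *before* the operation, so that neither the forward nor the backward pass ever sees an inf or nan.

```python
import torch

# Toy sketch of the "safe values" fix: 1/x is only valid for x > 0, so we
# substitute a dummy 1.0 where the branch is not taken. The dummy result
# (and its gradient) is then discarded by the outer torch.where().
x = torch.tensor([0.0, 1.0], requires_grad=True)
cond = x > 0
x_safe = torch.where(cond, x, torch.ones_like(x))  # dummy 1.0 in branch not taken
out = torch.where(cond, 1.0 / x_safe, torch.zeros_like(x))
out.sum().backward()
print(x.grad)                                      # tensor([0., -1.])
```

The discarded branch now contributes a finite (if meaningless) gradient, which the mask zeroes out cleanly instead of turning into nan.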

Best.

K. Frank


@KFrank, I looked into what you mentioned and, yes, that seems to be the root of my problem. In order to solve it, I first tried to use masked tensors but, unfortunately, they don't support the erfcx operation (yet). Next, I tried to implement "my own version" of masked tensors by masking the input values for each separate operation so that NaNs are not produced (as you suggested), applying the operations, and then masking the outputs again before adding up the results. This seems to work both for the forward and backward pass. Here is my new implementation of the function _F_1:

```python
def _F_1(x_, y_):
    # All values in tensor @x must be smaller (in magnitude) than values in tensor @y
    x = torch.where(torch.abs(x_) <= torch.abs(y_), x_, y_)
    y = torch.where(torch.abs(y_) >= torch.abs(x_), y_, x_)

    out1_cond = torch.abs(x - y) < 1e-7  # e.g. tensor([False, False, False, False])
    out2_cond = torch.logical_and(torch.logical_and(x <= 0, y <= 0),
                                  torch.logical_not(out1_cond))  # e.g. tensor([False,  True, False, False])
    out3_cond = torch.logical_and(torch.logical_and(x >= 0, y >= 0),
                                  torch.logical_not(out1_cond))  # e.g. tensor([False, False, False,  True])
    out4_cond = torch.logical_and(torch.logical_and(torch.logical_not(out1_cond),
                                                    torch.logical_not(out2_cond)),
                                  torch.logical_not(out3_cond))  # all other conditions must be false

    # Mask the input values for each operation with "safe" dummy values
    x1, y1 = torch.where(out1_cond, x, 0.0), torch.where(out1_cond, y, 0.0)
    x2, y2 = torch.where(out2_cond, x, -1.0), torch.where(out2_cond, y, 0.0)
    x3, y3 = torch.where(out3_cond, x, 0.0), torch.where(out3_cond, y, 1.0)
    x4, y4 = torch.where(out4_cond, x, 0.0), torch.where(out4_cond, y, 1.0)

    # Apply the operations
    delt = ParallelTruncatedGaussian._delta(x, y)
    one_minus_delt = 1 - delt

    out1_m = ParallelTruncatedGaussian._P_1(x1, y1 - x1)
    out2_m = one_minus_delt / (delt * erfcx(-y2) - erfcx(-x2))
    out3_m = one_minus_delt / (erfcx(x3) - delt * erfcx(y3))
    out4_m = (one_minus_delt * torch.exp(-x4**2)) / (erf(y4) - erf(x4))

    # Mask the outputs again before adding up the results
    out1 = torch.where(out1_cond, out1_m, 0.0)
    out2 = torch.where(out2_cond, out2_m, 0.0)
    out3 = torch.where(out3_cond, out3_m, 0.0)
    out4 = torch.where(out4_cond, out4_m, 0.0)

    # Add them up into a single tensor
    final_out = out1 + out2 + out3 + out4

    return final_out
```
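One way to sanity-check an implementation built on this mask-inputs/mask-outputs pattern (my suggestion, not part of the original repo) is `torch.autograd.gradcheck`, shown here on a self-contained toy function using the same pattern:

```python
import torch

# Toy two-branch function built with the same pattern: mask the input
# with a safe dummy before the unsafe op, then mask the output again.
def masked_recip(x):
    cond = x > 0
    x_safe = torch.where(cond, x, torch.ones_like(x))  # dummy where 1/x is unsafe
    return torch.where(cond, 1.0 / x_safe, x)

# gradcheck compares analytic and numerical gradients (requires float64)
x = torch.tensor([-2.0, 0.5, 3.0], dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(masked_recip, (x,)))  # True
```

gradcheck raises if any gradient is nan or disagrees with the finite-difference estimate, so it catches regressions in the masking logic early.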

Also, I have acknowledged your help, @KFrank and @ptrblck, in my GitHub repo.
