I am implementing a truncated Gaussian distribution in PyTorch (https://github.com/TheAeryan/stable-truncated-gaussian). To do this in a numerically stable manner, I need conditional code (*if-then-else*) that computes the mean, variance and log-probs in one way or another, depending on the specific input values. Here is an example:

```python
def _F_1(x, y):
    if torch.abs(x) > torch.abs(y):
        out = SequentialTruncatedGaussian._F_1(y, x)
    elif torch.abs(x - y) < 1e-7:
        out = SequentialTruncatedGaussian._P_1(x, y - x)
    elif x <= 0 and y <= 0:
        delt = SequentialTruncatedGaussian._delta(x, y)
        out = (1 - delt) / (delt*erfcx(-y) - erfcx(-x))
    elif x >= 0 and y >= 0:
        delt = SequentialTruncatedGaussian._delta(x, y)
        out = (1 - delt) / (erfcx(x) - delt*erfcx(y))
    else:
        delt = SequentialTruncatedGaussian._delta(x, y)
        out = ((1 - delt)*torch.exp(-x**2)) / (erf(y) - erf(x))
    return out
```

I am now adapting my code to make it parallel, so I need a way to express these *if-then-else* conditions in PyTorch. What I do is compute the *out* value associated with each condition and then use *torch.where* to select the correct value according to its associated condition. This is how I have implemented the *_F_1* function above:

```python
def _F_1(x_, y_):
    # Every value in tensor @x must be smaller in absolute value than the
    # corresponding value in tensor @y
    x = torch.where(torch.abs(x_) <= torch.abs(y_), x_, y_)
    y = torch.where(torch.abs(y_) >= torch.abs(x_), y_, x_)
    # Values
    delt = NewParallelTruncatedGaussian._delta(x, y)
    one_minus_delt = 1 - delt
    out1 = NewParallelTruncatedGaussian._P_1(x, y - x)
    out2 = one_minus_delt / (delt*erfcx(-y) - erfcx(-x) + EPS)
    out3 = one_minus_delt / (erfcx(x) - delt*erfcx(y) + EPS)
    out4 = (one_minus_delt*torch.exp(-x**2)) / (erf(y) - erf(x) + EPS)
    # Conditions
    x_cond, y_cond = x.detach(), y.detach()  # Make sure that gradients do not go through the conditions
    out1_cond = torch.abs(x_cond - y_cond) < 1e-7  # e.g., tensor([False, False, False, False])
    out2_cond = torch.logical_and(torch.logical_and(x_cond <= 0, y_cond <= 0),
                                  torch.logical_not(out1_cond))  # e.g., tensor([False, True, False, False])
    out3_cond = torch.logical_and(torch.logical_and(x_cond >= 0, y_cond >= 0),
                                  torch.logical_not(out1_cond))  # e.g., tensor([False, False, False, True])
    out4_cond = torch.logical_and(torch.logical_and(torch.logical_not(out1_cond),
                                                    torch.logical_not(out2_cond)),
                                  torch.logical_not(out3_cond))  # All the other conditions must be false, e.g., tensor([True, False, True, False])
    # Final expression: exactly one condition holds per position
    final_out = (torch.where(out1_cond, out1, 0) + torch.where(out2_cond, out2, 0)
                 + torch.where(out3_cond, out3, 0) + torch.where(out4_cond, out4, 0))
    return final_out
```
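In isolation, the select-by-sum-of-*torch.where* pattern I am using looks like this (with made-up branch expressions, just to show the mechanics):

```python
import torch

x = torch.tensor([-2.0, -0.5, 0.5, 2.0])
# Branchless, elementwise version of: "if x < 0: x**2 else: x + 1".
# The condition is detached so gradients do not flow through it.
cond = x.detach() < 0
out = torch.where(cond, x**2, 0.) + torch.where(~cond, x + 1, 0.)
print(out)  # tensor([4.0000, 0.2500, 1.5000, 3.0000])
```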

This works perfectly for the forward pass but, in the backward pass (i.e., when computing the gradients), **I get NaN gradients**. What happens is that some of these expressions produce *NaNs* or *infs* at some positions. The values at these positions are discarded by *torch.where* (i.e., replaced by 0). Despite this, I still get NaN gradients for the final result (*final_out*), **even though the values which produce NaN gradients are not used in computing final_out, since torch.where discards them**. In other words, I think that if I were able to **substitute NaN gradients with some other value** (e.g., 0), the gradient computation would work perfectly.
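Here is a minimal, self-contained reproduction of the effect, unrelated to my distribution code (using `torch.log` as a stand-in for any branch that misbehaves on part of its input):

```python
import torch

x = torch.tensor([0.0, 1.0], requires_grad=True)
bad = torch.log(x)  # log(0) = -inf at the first position
# The first position selects the zero branch, so the forward value is fine...
out = torch.where(x > 0, bad, torch.zeros_like(x)).sum()
out.backward()
# ...but backward still evaluates 0 * d(log x)/dx = 0 * inf = nan there
print(x.grad)  # tensor([nan, 1.])
```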

Any suggestions on how to solve this issue?