I am implementing a truncated Gaussian distribution in PyTorch (https://github.com/TheAeryan/stable-truncated-gaussian). To do this in a numerically stable manner, I need conditional (if-then-else) code that computes the mean, variance and log-probs in one way or another, depending on the specific input values. Here is an example:
```python
def _F_1(x, y):
    if torch.abs(x) > torch.abs(y):
        out = SequentialTruncatedGaussian._F_1(y, x)
    elif torch.abs(x - y) < 1e-7:
        out = SequentialTruncatedGaussian._P_1(x, y-x)
    elif x <= 0 and y <= 0:
        delt = SequentialTruncatedGaussian._delta(x, y)
        out = (1 - delt) / (delt*erfcx(-y) - erfcx(-x))
    elif x >= 0 and y >= 0:
        delt = SequentialTruncatedGaussian._delta(x, y)
        out = (1 - delt) / (erfcx(x) - delt*erfcx(y))
    else:
        delt = SequentialTruncatedGaussian._delta(x, y)
        out = ((1-delt)*torch.exp(-x**2)) / (erf(y)-erf(x))
    return out
```
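The scalar version cannot be applied directly to batched tensors, because a Python `if` needs a single boolean. A minimal illustration, using a hypothetical toy function `toy_scalar` with the same if-else shape as `_F_1` (not part of the original code), together with its element-wise `torch.where` counterpart `toy_parallel`:

```python
import torch

def toy_scalar(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical toy function: plain Python if-else, one value at a time
    if x > 0:
        return x ** 2
    return -x

def toy_parallel(x: torch.Tensor) -> torch.Tensor:
    # Evaluate both branches on the whole tensor, then select element-wise
    return torch.where(x > 0, x ** 2, -x)

print(toy_scalar(torch.tensor(2.0)))  # works: a 0-d tensor converts to one bool

try:
    toy_scalar(torch.tensor([-1.0, 2.0]))
except RuntimeError as err:
    # `if` on a multi-element tensor raises
    # "Boolean value of Tensor with more than one value is ambiguous"
    print(err)

print(toy_parallel(torch.tensor([-1.0, 2.0])))  # handles the whole batch
```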
I am now adapting my code so that it runs in parallel, so I need a way to express these if-then-else conditions in PyTorch. My approach is to compute the `out` value associated with each condition and then use `torch.where` to select the correct value according to its associated condition. This is how I have implemented the `_F_1` function above:
```python
def _F_1(x_, y_):
    # Swap elements so that |x| <= |y| holds element-wise
    x = torch.where(torch.abs(x_) <= torch.abs(y_), x_, y_)
    y = torch.where(torch.abs(y_) >= torch.abs(x_), y_, x_)
    # Values (EPS is a small constant defined elsewhere)
    delt = NewParallelTruncatedGaussian._delta(x, y)
    one_minus_delt = 1 - delt
    out1 = NewParallelTruncatedGaussian._P_1(x, y - x)
    out2 = one_minus_delt / (delt * erfcx(-y) - erfcx(-x) + EPS)
    out3 = one_minus_delt / (erfcx(x) - delt * erfcx(y) + EPS)
    out4 = (one_minus_delt * torch.exp(-x**2)) / (erf(y) - erf(x) + EPS)
    # Conditions
    x_cond, y_cond = x.detach(), y.detach()  # Make sure that gradients do not go through the conditions
    out1_cond = torch.abs(x_cond - y_cond) < 1e-7  # tensor([False, False, False, False])
    out2_cond = torch.logical_and(torch.logical_and(x_cond <= 0, y_cond <= 0),
                                  torch.logical_not(out1_cond))  # tensor([False, True, False, False])
    out3_cond = torch.logical_and(torch.logical_and(x_cond >= 0, y_cond >= 0),
                                  torch.logical_not(out1_cond))  # tensor([False, False, False, True])
    # All the other conditions must be false
    out4_cond = torch.logical_and(torch.logical_and(torch.logical_not(out1_cond), torch.logical_not(out2_cond)),
                                  torch.logical_not(out3_cond))  # tensor([ True, False, True, False])
    # Final expression
    final_out = (torch.where(out1_cond, out1, 0) + torch.where(out2_cond, out2, 0)
                 + torch.where(out3_cond, out3, 0) + torch.where(out4_cond, out4, 0))
    return final_out
```
This works perfectly for the forward pass, but in the backward pass (i.e., when computing the gradients) I get NaN gradients. Some of these expressions evaluate to NaN or inf at certain positions. However, `torch.where` discards the values at those positions (i.e., replaces them with 0), so they are never used to compute `final_out`. Despite this, the gradients of `final_out` still contain NaNs. In other words, I think that if I could replace the NaN gradients with some other value (e.g., 0), the gradient computation would work perfectly.
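The same behaviour can be reproduced with a much smaller example. This is a sketch in which `torch.log` stands in for the unstable `_F_1` branches (it is not the original code): the discarded branch produces `-inf` at one position, and although the forward result is finite, the gradient at that position is NaN.

```python
import torch

# Toy input: log(x) is only valid where x > 0
x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)

# Forward pass: compute the unsafe branch everywhere, discard it where x <= 0
unsafe_branch = torch.log(x)  # log(0) = -inf at position 0
out = torch.where(x > 0, unsafe_branch, torch.zeros_like(x))
print(out)  # forward values are all finite

out.sum().backward()
# torch.where sends a gradient of 0 to unsafe_branch at position 0, but the
# backward of log divides that gradient by x, i.e. computes 0 / 0 = NaN
print(x.grad)  # NaN at position 0, finite elsewhere
```

The NaN is therefore created inside the backward of the discarded branch itself, after `torch.where` has already routed a zero gradient into it, which is why the masking in the forward pass does not prevent it.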
Any suggestions on how to solve this issue?
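One commonly used workaround (often called the "double where" trick; it is not taken from the original code) is to mask the *inputs* of each unsafe branch as well as its output, so that the discarded branch is evaluated only on harmless values and its local derivative stays finite. A minimal sketch on the same toy `log` example, where `safe_log_where` is a hypothetical name:

```python
import torch

def safe_log_where(x: torch.Tensor) -> torch.Tensor:
    cond = x > 0
    # Step 1: replace the branch's inputs with a harmless value (1.0 here)
    # at every position where the branch will be discarded anyway
    x_safe = torch.where(cond, x, torch.ones_like(x))
    # Step 2: the branch is now finite everywhere, and so is its derivative
    branch = torch.log(x_safe)
    # Step 3: select as before; discarded positions get 0 and a 0 gradient
    return torch.where(cond, branch, torch.zeros_like(x))

x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
safe_log_where(x).sum().backward()
print(x.grad)  # finite everywhere; 0 where the branch was discarded
```

For `_F_1`, this would mean computing each `outN` from copies of `x` and `y` in which the positions outside `outN_cond` are replaced by values for which that branch is well-defined. Which replacement values are safe for `_delta`, `_P_1` and the `erfcx`/`erf` expressions is an assumption that would need to be checked for each branch.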