Incorrect gradient calculation with torch.where and NaNs

I am implementing a truncated Gaussian distribution in PyTorch (https://github.com/TheAeryan/stable-truncated-gaussian). To do that in a numerically stable manner, I need conditional code (if-then-else) that computes the mean, variance and log-probs with a different formula depending on the input values. Here is an example:

```python
# Method of SequentialTruncatedGaussian (scalar inputs); erfcx and erf are the
# scaled complementary and regular error functions (e.g. torch.special.erfcx, torch.special.erf)
def _F_1(x, y):
    if torch.abs(x) > torch.abs(y):
        out = SequentialTruncatedGaussian._F_1(y, x)
    elif torch.abs(x - y) < 1e-7:
        out = SequentialTruncatedGaussian._P_1(x, y-x)
    elif x <= 0 and y <= 0:
        delt = SequentialTruncatedGaussian._delta(x, y)
        out = (1 - delt) / (delt*erfcx(-y) - erfcx(-x))
    elif x >= 0 and y >= 0:
        delt = SequentialTruncatedGaussian._delta(x, y)
        out = (1 - delt) / (erfcx(x) - delt*erfcx(y))
    else:
        delt = SequentialTruncatedGaussian._delta(x, y)
        out = ((1-delt)*torch.exp(-x**2)) / (erf(y)-erf(x))

    return out
```

I am now adapting my code to make it parallel (i.e., vectorized over whole tensors), so I need a way to express these if-then-else conditions with tensor operations. My approach is to compute the out value associated with each condition and then use torch.where to select, for each element, the value whose condition holds. This is how I have implemented the _F_1 function above:
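As a side illustration on a hypothetical toy function (`piecewise`, not the truncated-Gaussian code), an if/elif/else chain can also be written as nested torch.where calls instead of a sum of masked terms. The outermost condition is tested first, so each element takes the first matching branch, just like `elif`, and the conditions need not be made mutually exclusive by hand. This only changes how the forward values are selected: every branch is still evaluated on every element, so it does not by itself avoid NaN gradients from the unselected branches.

```python
import torch

def piecewise(x: torch.Tensor) -> torch.Tensor:
    # Scalar version for reference:
    #   if x < -1:   return -x
    #   elif x < 1:  return x**2
    #   else:        return 2*x - 1
    # Nested torch.where: the outer condition is checked first, so the
    # first matching branch wins for each element, as with if/elif/else.
    return torch.where(x < -1, -x,
                       torch.where(x < 1, x**2, 2 * x - 1))

x = torch.tensor([-2.0, 0.5, 3.0])
print(piecewise(x))  # values 2.0, 0.25, 5.0
```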

```python
def _F_1(x_, y_):
    # Swap elementwise so that |x| <= |y| in every position
    x = torch.where(torch.abs(x_)<=torch.abs(y_),x_,y_)
    y = torch.where(torch.abs(y_)>=torch.abs(x_),y_,x_)

    # Values (EPS is a small constant defined elsewhere)
    delt = NewParallelTruncatedGaussian._delta(x, y)
    one_minus_delt = 1 - delt
    out1 = NewParallelTruncatedGaussian._P_1(x, y-x)
    out2 = one_minus_delt / (delt*erfcx(-y) - erfcx(-x) + EPS)
    out3 = one_minus_delt / (erfcx(x) - delt*erfcx(y) + EPS)
    out4 = (one_minus_delt*torch.exp(-x**2)) / (erf(y)-erf(x) + EPS)

    # Conditions
    x_cond, y_cond = x.detach(), y.detach() # Make sure that gradients do not go through the conditions

    out1_cond = torch.abs(x_cond - y_cond) < 1e-7 # tensor([False, False, False, False])
    out2_cond = torch.logical_and(torch.logical_and(x_cond<=0, y_cond<=0), torch.logical_not(out1_cond)) # tensor([False,  True, False, False])
    out3_cond = torch.logical_and(torch.logical_and(x_cond>=0, y_cond>=0), torch.logical_not(out1_cond)) # tensor([False, False, False,  True])
    out4_cond = torch.logical_and(torch.logical_and(torch.logical_not(out1_cond), torch.logical_not(out2_cond)),
                                  torch.logical_not(out3_cond)) # All the other conditions must be false # tensor([ True, False,  True, False])

    # Final expression
    final_out = (torch.where(out1_cond, out1, 0) + torch.where(out2_cond, out2, 0)
                 + torch.where(out3_cond, out3, 0) + torch.where(out4_cond, out4, 0))

    return final_out
```

This works perfectly for the forward pass, but in the backward pass (i.e., when computing the gradients) I get NaN gradients. Some of these expressions evaluate to NaN or inf at certain positions. Those values are discarded by torch.where (i.e., replaced by 0), so they are never used to compute final_out. Even so, the gradient of final_out contains NaNs. In other words, I think that if I were able to replace the NaN gradients with any other value (e.g., 0), the gradient calculations would work perfectly.
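A minimal sketch that reproduces this behaviour, with sqrt as a stand-in for the truncated-Gaussian formulas (the toy function is an assumption, not the original code). torch.where sends a zero gradient into the branch it discards, but the chain rule then combines that zero with the branch's local derivative, which is infinite for sqrt at 0, and the result is NaN rather than 0:

```python
import torch

x = torch.tensor([0.0, 1.0], requires_grad=True)

# Forward: sqrt(0) = 0 is even finite here, and torch.where discards it anyway
y = torch.where(x > 0, torch.sqrt(x), torch.zeros_like(x))
y.sum().backward()

# Backward: the discarded sqrt branch receives a zero upstream gradient, but
# sqrt's derivative at 0 is infinite, so that position ends up as nan
print(x.grad)  # first entry nan, second 0.5
```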

Any suggestions on how to solve this issue?

Could you post an example input which would yield valid values in the first example but NaNs in the latter?

Hi Carlos!

First, this is a known issue (with no simple fix for torch.where()). See,
for example, GitHub issues 68425 and 70342.

Second, I believe that the best fix is to avoid producing nans (or the values
that subsequently produce nans) in the forward pass, even in the “branch
not taken” of the torch.where(). Instead, compute “safe” (non-nan-producing)
values in their place. These values may be incorrect, but that is fine,
because torch.where() will successfully discard the incorrect (but non-nan)
gradients that they lead to.
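A minimal sketch of this “safe value” idea (sometimes called the double-where trick), again with sqrt as an assumed stand-in: an inner torch.where replaces the inputs of the branch not taken with a value (1.0 here) for which both sqrt and its derivative are finite, and the outer torch.where selects the result as before.

```python
import torch

x = torch.tensor([0.0, 1.0], requires_grad=True)
mask = x > 0

# Inner where: feed a "safe" input (1.0) to sqrt where its branch is not taken
x_safe = torch.where(mask, x, torch.ones_like(x))

# Outer where: select the branch as before; the discarded values are now
# finite, so the zero upstream gradient stays zero instead of becoming nan
y = torch.where(mask, torch.sqrt(x_safe), torch.zeros_like(x))
y.sum().backward()
print(x.grad)  # 0.0 and 0.5, no nan
```

The value sqrt(1.0) computed at the masked position is wrong for that input, but it is never selected and its gradient is discarded, which is exactly the point made above.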

For a concrete example of how to feed “safe” values to the “branch not taken”
in a problem analogous to yours, see this post:

Best.

K. Frank


Thank you @ptrblck and @KFrank for your answers, and sorry for the late reply.

@KFrank, I looked into what you mentioned and, yes, that seems to be the root of my problem. To solve it, I first tried to use masked tensors but, unfortunately, they don’t support the erfcx operation (yet). Next, I implemented “my own version” of masked tensors: I mask the input values of each separate operation so that NaNs are not produced (as you suggested), apply the operations, and then mask the outputs again before adding up the results. This seems to work for both the forward and the backward pass. Here is my new implementation of the function _F_1:

```python
def _F_1(x_, y_):
    # Swap elementwise so that |x| <= |y| in every position
    x = torch.where(torch.abs(x_)<=torch.abs(y_),x_,y_)
    y = torch.where(torch.abs(y_)>=torch.abs(x_),y_,x_)

    # Obtain masks
    with torch.no_grad():
        out1_cond = torch.abs(x - y) < 1e-7 # tensor([False, False, False, False])
        out2_cond = torch.logical_and(torch.logical_and(x<=0, y<=0), torch.logical_not(out1_cond)) # tensor([False,  True, False, False])
        out3_cond = torch.logical_and(torch.logical_and(x>=0, y>=0), torch.logical_not(out1_cond)) # tensor([False, False, False,  True])
        out4_cond = torch.logical_and(torch.logical_and(torch.logical_not(out1_cond), torch.logical_not(out2_cond)),
                                      torch.logical_not(out3_cond)) # All the other conditions must be false

    # Mask input values for each operation, replacing them with "safe" values
    x1, y1 = torch.where(out1_cond, x, 0), torch.where(out1_cond, y, 0)
    x2, y2 = torch.where(out2_cond, x, -1), torch.where(out2_cond, y, 0)
    x3, y3 = torch.where(out3_cond, x, 0), torch.where(out3_cond, y, 1)
    x4, y4 = torch.where(out4_cond, x, 0), torch.where(out4_cond, y, 1)

    # Apply operations
    delt = ParallelTruncatedGaussian._delta(x, y)
    one_minus_delt = 1 - delt

    out1_m = ParallelTruncatedGaussian._P_1(x1, y1-x1)
    out2_m = one_minus_delt / (delt*erfcx(-y2) - erfcx(-x2))
    out3_m = one_minus_delt / (erfcx(x3) - delt*erfcx(y3))
    out4_m = (one_minus_delt*torch.exp(-x4**2)) / (erf(y4)-erf(x4))

    # Unmask tensors, by setting masked values to 0
    out1 = torch.where(out1_cond, out1_m, 0)
    out2 = torch.where(out2_cond, out2_m, 0)
    out3 = torch.where(out3_cond, out3_m, 0)
    out4 = torch.where(out4_cond, out4_m, 0)

    # Add them up into a single tensor
    final_out = out1 + out2 + out3 + out4

    return final_out
```
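To check that this masking pattern gives gradients that are correct and not merely finite, one option is torch.autograd.gradcheck, which compares autograd gradients with finite differences in float64. Below is a sketch on a hypothetical toy function `f` (log(x) for x > 0, x - 1 otherwise) that follows the same mask inputs / apply / mask outputs / sum steps. The gradcheck points are kept away from the branch boundary, where `f` is discontinuous and finite differences are meaningless.

```python
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical toy piecewise function: log(x) for x > 0, x - 1 otherwise
    with torch.no_grad():
        pos = x > 0
        neg = ~pos
    # Mask the inputs of each branch with a safe value (log(1.0) is finite)
    x_pos = torch.where(pos, x, torch.ones_like(x))
    x_neg = torch.where(neg, x, torch.zeros_like(x))
    # Apply each branch, set the masked positions to 0, and add up
    out_pos = torch.where(pos, torch.log(x_pos), torch.zeros_like(x))
    out_neg = torch.where(neg, x_neg - 1, torch.zeros_like(x))
    return out_pos + out_neg

# Away from the boundary, autograd and finite-difference gradients agree
x = torch.tensor([-1.5, 0.7, 2.0], dtype=torch.float64, requires_grad=True)
print(torch.autograd.gradcheck(f, (x,)))  # True

# At x = 0 an unmasked log branch would produce log(0) = -inf and a nan
# gradient; with masked inputs the gradient comes only from the x - 1 branch
x0 = torch.tensor([0.0], dtype=torch.float64, requires_grad=True)
f(x0).sum().backward()
print(x0.grad)  # 1.0
```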

Also, I have acknowledged your help, @KFrank and @ptrblck, in my GitHub repo 😁.
