Getting NaNs in gradients while training with Dice loss

I’m trying to train a network for 2-class pixel-wise segmentation. To handle the class imbalance, I’m using the Dice loss. It works well with a baseline network that simply predicts the probability of each pixel being 1. But in a second network, the outputs for each pixel are the parameters of a Beta distribution, from which samples are drawn. The mean of these samples is used to compute the Dice loss. This works up to a certain point, after which the gradients start containing NaNs. The loss value itself isn’t NaN, however. The NaNs in the gradients also do not occur when I train with BCE loss. I’m using a smoothed Dice loss with smoothness parameter = 1.

Here are some snippets:

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.distributions import Beta

class WNetBeta(nn.Module):
	def __init__(self, in_ch, out_ch, C0=32):
		super(WNetBeta, self).__init__()
		self.conv0 = conv(in_ch, C0, (1, 3, 3), padding=(0, 1, 1))

		self.resblock1 = ResBlock(C0, C0)
		self.resblock2 = ResBlock(C0, C0)
		self.tconv1 = tconv(C0, C0)
		
		self.avgsample = nn.AvgPool3d((1, 2, 2), stride=(1, 2, 2))

		self.resblock3 = ResBlock(C0, C0)
		self.resblock4 = ResBlock(C0, C0)

		self.tconv2 = tconv(C0, C0)

		self.pconv1 = pconv(C0, out_ch, (1, 3, 3), padding=(0, 1, 1))

	def forward(self, x):
		data = self.conv0(x)
		data = self.resblock1(data)
		data = self.resblock2(data)
		data = self.tconv1(data)
		data = self.avgsample(data)
		
		data = self.resblock3(data)
		data = self.resblock4(data)
		data = self.tconv2(data)
		out = self.pconv1(data)

		## get alpha, beta
		logalpha = out[:, 0]
		logbeta  = out[:, 1]

		alpha = torch.exp(logalpha)
		beta  = torch.exp(logbeta)

		# Draw T reparameterized samples per pixel from Beta(alpha, beta)
		# and average them to get a differentiable probability map
		T = 10
		alphar = alpha.unsqueeze(1).repeat(1, T, 1, 1, 1, 1)
		betar  = beta.unsqueeze(1).repeat(1, T, 1, 1, 1, 1)
		m = Beta(alphar, betar)
		p = m.rsample().mean(1)

		return p, logalpha, logbeta

Think of the blocks ResBlock, pconv, etc. as conv blocks with different parameters. I haven’t pasted their code for brevity, and because it’s probably not relevant here.

The baseline is the same, except that logalpha and logbeta are replaced by the segmentation logits. The Dice loss (rather, 1 − Dice score) is given by the following code:

def dice_loss(pred, target, smooth=1.):
	"""Soft Dice loss (1 - Dice score).

	This definition generalizes to real-valued pred and target tensors
	and is differentiable.
	pred: tensor with first dimension as batch
	target: tensor with first dimension as batch
	"""
	# contiguous() is needed since the tensors may come from a torch.view op
	# iflat, tflat have shape (B, N)
	iflat = torch.flatten(pred.contiguous(), 1)
	tflat = torch.flatten(target.contiguous(), 1)
	intersection = (iflat * tflat).sum(1)
	A_sum = iflat.sum(1)
	B_sum = tflat.sum(1)

	loss = 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth)).mean()
	return loss

In the case of the Beta-parameterized network, the Dice loss is computed on the mean of the samples.
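
Concretely, the training step in the Beta case looks roughly like this (just a sketch; model, x, and target are illustrative names):

p, logalpha, logbeta = model(x)         # p is the mean of the T Beta samples
loss = dice_loss(p, target, smooth=1.)  # target is the binary ground-truth mask
loss.backward()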

Can anyone tell me what is going wrong here? Thank you.

If it gives NaNs in the gradients, then your parameters would become NaN, and you’d see NaN in the loss in the next iteration as well. I take it that “it starts giving NaNs in the gradients. The loss value isn’t NaN, however.” means you checked the gradients manually and saw that the loss wasn’t NaN but the gradients were?

In the current setting, I skip optim.step() if there’s a NaN in the gradients, roughly as in the sketch below.
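
A minimal sketch of that check (model and optimizer are illustrative names):

def grads_have_nan(model):
	# True if any parameter gradient contains a NaN
	return any(
		torch.isnan(p.grad).any()
		for p in model.parameters()
		if p.grad is not None
	)

loss.backward()
if not grads_have_nan(model):
	optimizer.step()
optimizer.zero_grad()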

I checked, and if the parameters become NaN, the Dice loss suddenly drops, but doesn’t become NaN.

My bad, the Dice loss drops in the sense that if the average Dice score was 0.66 (good segmentation), it became 0.04 (pretty bad).

When NaNs arise, all computations involving them become NaN as well, so it’s curious that your parameters turning NaN still lead to real-valued losses. It might be that you have a NaN catcher implemented somewhere. AFAIK it is not good practice to conceal NaNs this way, as it prevents you from figuring out what is causing them in the first place. Regardless of the other problems, I’d say this needs to be figured out, even though it may seem paradoxical that you’d want the NaNs to show themselves.

While I don’t know about the code in particular, I’d suggest using backward hooks, or retain_grad, to look at the gradients of all the layers and figure out where the NaNs first pop up. NaN is basically inf − inf, inf / inf, or 0 / 0. Since number representations are bounded, you can hit all of these cases; even the infs will arise from a division by zero somewhere.
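
For example, here is a minimal sketch that hooks every named submodule and prints the max/min of its gradients (model is an illustrative name; on older PyTorch versions register_backward_hook plays the same role as register_full_backward_hook):

def make_grad_hook(name):
	# Print max/min of the gradients flowing into and out of one module
	def hook(module, grad_input, grad_output):
		for i, g in enumerate(grad_input):
			if g is not None:
				print('inp {}, {}, {}, {}'.format(name, i, g.max().item(), g.min().item()))
		for i, g in enumerate(grad_output):
			if g is not None:
				print('out {}, {}, {}, {}'.format(name, i, g.max().item(), g.min().item()))
	return hook

for name, module in model.named_modules():
	if name:  # skip the root module itself
		module.register_full_backward_hook(make_grad_hook(name))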

@tumble-weed I’m not aware of such a NaN catcher, but I suspect that the Beta distribution may output samples from a ‘trivial’ distribution (like the uniform distribution) when it gets NaN as input. I suspect the NaNs arise in the Dice loss function, because I never got a NaN when training with the cross-entropy loss. I have posted the network code above. I’ll try using hooks or retain_grad. Could you point me to some resources on debugging this kind of issue?

Edit: I tried using hooks to find out. It turns out that the NaNs are preceded by very large gradient values at some layers. For example, here is one output from the hooks:

inp pconv4, 0, 1.60642921401e-05, -2.30273017223e-05
inp pconv4, 1, 0.1004197523, -0.0698678120971
inp pconv4, 2, -0.114721596241, -0.114721596241
out pconv4, 0, 1.54108820425e-05, -1.4961476154e-05
inp pconv3, 0, 2.00003178179e-05, -2.27984455705e-05
inp pconv3, 1, -3.40282346639e+38, 3.40282346639e+38
inp pconv3, 2, -3.40282346639e+38, 3.40282346639e+38
out pconv3, 0, 8.05776580819e-05, -0.000112374822493
inp tconv4, 0, 1.4272424778e-05, -2.27984455705e-05
inp tconv4, 1, -3.40282346639e+38, 3.40282346639e+38
out tconv4, 0, 2.00003178179e-05, -2.27984455705e-05
inp res10, 0, 1.41433047247e-05, -8.82773838384e-06
inp res10, 1, 1.41433047247e-05, -8.82773838384e-06
out res10, 0, 1.41433047247e-05, -8.82773838384e-06
inp res9, 0, 1.33553148771e-05, -9.24733012653e-06
inp res9, 1, 1.33553148771e-05, -9.24733012653e-06
out res9, 0, 1.33553148771e-05, -9.24733012653e-06
inp res8, 0, 1.54191238835e-05, -1.03389274955e-05
inp res8, 1, 1.54191238835e-05, -1.03389274955e-05
out res8, 0, 1.54191238835e-05, -1.03389274955e-05
inp pconv2, 0, 9.37501172302e-05, -0.000109870823508
inp pconv2, 1, -3.40282346639e+38, 3.40282346639e+38
inp pconv2, 2, -3.40282346639e+38, 3.40282346639e+38
out pconv2, 0, 0.000202267314307, -0.000283791217953
inp tconv3, 0, 3.52506276613e-06, -9.23275365494e-06
inp tconv3, 1, -3.40282346639e+38, 3.40282346639e+38
out tconv3, 0, 4.74067655887e-06, -9.23275365494e-06
inp res7, 0, 2.18382069761e-07, -1.71947704075e-07
inp res7, 1, 2.18382069761e-07, -1.71947704075e-07
out res7, 0, 2.18382069761e-07, -1.71947704075e-07
inp res6, 0, nan, nan
inp res6, 1, nan, nan
out res6, 0, nan, nan
inp res5, 0, nan, nan
inp res5, 1, nan, nan
out res5, 0, nan, nan
inp AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0), 9.59493546443e+14, -7.67205847335e+14
out AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0), nan, nan
inp pconv1, 0, 2.66912047664e-05, -2.00852118724e-05
inp pconv1, 1, -3.40282346639e+38, 3.40282346639e+38
inp pconv1, 2, -3.40282346639e+38, 3.40282346639e+38
out pconv1, 0, 4.03012127208e-05, -6.12536241533e-05
inp tconv2, 0, 9.59493546443e+14, -5.2268440014e+14
inp tconv2, 1, -3.40282346639e+38, 3.40282346639e+38
out tconv2, 0, 9.59493546443e+14, -7.67205847335e+14
inp res4, 0, 3.21442163982e+14, -2.90692882498e+14
inp res4, 1, 3.21442163982e+14, -2.90692882498e+14
out res4, 0, 3.21442163982e+14, -2.90692882498e+14
inp res3, 0, 4.12539561181e+14, -3.18285765673e+14
inp res3, 1, 4.12539561181e+14, -3.18285765673e+14
out res3, 0, 4.12539561181e+14, -3.18285765673e+14
inp AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0), nan, nan
out AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0), 5.69794319352e+14, -5.25037538902e+14
inp tconv1, 0, nan, nan
inp tconv1, 1, -3.40282346639e+38, 3.40282346639e+38
out tconv1, 0, nan, nan
inp res2, 0, nan, nan
inp res2, 1, nan, nan
out res2, 0, nan, nan
inp conv(
  (layer): Sequential(
    (0): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
), nan, nan
inp conv(
  (layer): Sequential(
    (0): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
), -3.40282346639e+38, 3.40282346639e+38
out conv(
  (layer): Sequential(
    (0): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
), nan, nan
inp conv(
  (layer): Sequential(
    (0): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
), nan, nan
inp conv(
  (layer): Sequential(
    (0): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
), -3.40282346639e+38, 3.40282346639e+38
out conv(
  (layer): Sequential(
    (0): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
), nan, nan
inp res1, 0, nan, nan
inp res1, 1, nan, nan
out res1, 0, nan, nan
inp conv(
  (layer): Sequential(
    (0): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
), nan, nan
inp conv(
  (layer): Sequential(
    (0): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
), -3.40282346639e+38, 3.40282346639e+38
out conv(
  (layer): Sequential(
    (0): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
), nan, nan
inp conv(
  (layer): Sequential(
    (0): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
), nan, nan
inp conv(
  (layer): Sequential(
    (0): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
), -3.40282346639e+38, 3.40282346639e+38
out conv(
  (layer): Sequential(
    (0): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
), nan, nan
inp conv1, 0, nan, nan
inp conv1, 1, -3.40282346639e+38, 3.40282346639e+38
out conv1, 0, nan, nan

There are a few places where the gradients’ max and min come out extremely large, and the NaNs follow from there. I’m already using residual blocks, BatchNorm, and PReLU to prevent gradient explosion.
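
For reference, one way to cap spikes like these before they reach the weights would be to clip the global gradient norm before the optimizer step (a sketch only; the max_norm of 1.0 is an arbitrary placeholder, not a tuned value, and model/optimizer are illustrative names):

loss.backward()
# Rescale all gradients so their combined L2 norm is at most max_norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()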