ReLU function results in NaNs

I am using a capsule network model, and at a certain point of training the ReLU in the conv layer starts returning NaNs. Running the training under with torch.autograd.detect_anomaly():, I get this error:

Traceback (most recent call last):

File "C:\Users\obouldjedr\Desktop\lastcode4\cp\test_capsnet.py", line 532, in <module>
train(capsule_net, optimizer, trainloaderIMUhand, e)

File "C:\Users\obouldjedr\Desktop\lastcode4\cp\test_capsnet.py", line 217, in train
loss.backward()

File "C:\Users\obouldjedr\Anaconda3\lib\site-packages\torch\tensor.py", line 245, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File "C:\Users\obouldjedr\Anaconda3\lib\site-packages\torch\autograd\__init__.py", line 145, in backward
Variable._execution_engine.run_backward(

RuntimeError: Function 'DivBackward0' returned nan values in its 0th output.

Does anyone know why? Thanks.

This might be caused by exploding gradients. You could try clipping the gradients with torch.nn.utils.clip_grad_value_ or torch.nn.utils.clip_grad_norm_.

To verify this, you can print the gradient values during training and see if they grow towards infinity, e.g. as in the sketch below.
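
A minimal sketch of what that could look like (the tiny model and data here are placeholders, not your capsule net):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)  # stand-in for capsule_net
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# Either clip each gradient element to a range...
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
# ...or rescale the total gradient norm (pick one):
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Print per-parameter gradient norms to spot values blowing up.
for name, p in model.named_parameters():
    if p.grad is not None:
        print(name, p.grad.norm().item())

optimizer.step()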

I’m a bit confused about the information that the NaN values are created by the ReLU layer, while the error points to DivBackward0. Could you check where a division is used and make sure it’s creating valid outputs?

Indeed, I forgot to mention this detail. Before getting NaNs (the whole tensor returned by ReLU was NaN), I got this at an earlier stage. There is a function called squash that is meant to scale the values to between 0 and 1; the code is below:

def squash(self, input_tensor):
    squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
    output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
    return output_tensor
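
(For reference, written out from the code above, this is the capsule squashing nonlinearity v_j = (||s_j||^2 / (1 + ||s_j||^2)) * (s_j / ||s_j||), which maps any input vector to one with norm in [0, 1).)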

Using torch.isnan(...).any() on every tensor, the first NaN appears in the tensor produced by the squashing operation (output_tensor), and from there it propagates through the rest of the tensors until the loss suddenly turns into NaN. What I did is use the new built-in PyTorch function torch.nan_to_num to turn the NaNs into 0; after this, the ReLU of the conv layer started returning tensors that were entirely NaN. Roughly, the check I ran looked like the sketch below.
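
A simplified sketch of that check (check_nan is just a hypothetical helper name):

import torch

def check_nan(name, t):
    # Report whether any element of the tensor is NaN.
    if torch.isnan(t).any():
        print(f"first NaN found in {name}")

# e.g. right after the squash call:
# check_nan("output_tensor", output_tensor)

# and the masking I tried afterwards (replaces NaN with 0.0):
# output_tensor = torch.nan_to_num(output_tensor)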
Below is the whole code of the capsule net:

##########################################################

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

USE_CUDA = True if torch.cuda.is_available() else False

class ConvLayer(nn.Module):
    def __init__(self, in_channels=1, out_channels=256, kernel_size=9):
        super(ConvLayer, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=1
                              )

    def forward(self, x):
        return F.relu(self.conv(x))

class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9, num_routes=32 * 6 * 6):
        super(PrimaryCaps, self).__init__()
        self.num_routes = num_routes
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0)
            for _ in range(num_capsules)])

    def forward(self, x):
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=1)
        u = u.view(x.size(0), self.num_routes, -1)
        return self.squash(u)

    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

class DigitCaps(nn.Module):
    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules

        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))

    def forward(self, x):
        batch_size = x.size(0)
        x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)

        W = torch.cat([self.W] * batch_size, dim=0)
        u_hat = torch.matmul(W, x)

        b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))
        if USE_CUDA:
            b_ij = b_ij.cuda()

        num_iterations = 3
        for iteration in range(num_iterations):
            c_ij = F.softmax(b_ij, dim=1)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = self.squash(s_j)

            if iteration < num_iterations - 1:
                a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

        return v_j.squeeze(1)

    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

class Decoder(nn.Module):
    def __init__(self, input_width=28, input_height=28, input_channel=1):
        super(Decoder, self).__init__()
        self.input_width = input_width
        self.input_height = input_height
        self.input_channel = input_channel
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, self.input_width * self.input_height * self.input_channel),
            nn.Sigmoid()
        )

    def forward(self, x, data):
        classes = torch.sqrt((x ** 2).sum(2))
        classes = F.softmax(classes, dim=0)

        _, max_length_indices = classes.max(dim=1)
        masked = Variable(torch.sparse.torch.eye(10))
        if USE_CUDA:
            masked = masked.cuda()
        masked = masked.index_select(dim=0, index=Variable(max_length_indices.squeeze(1).data))
        t = (x * masked[:, :, None, None]).view(x.size(0), -1)
        reconstructions = self.reconstraction_layers(t)
        reconstructions = reconstructions.view(-1, self.input_channel, self.input_width, self.input_height)
        return reconstructions, masked

class CapsNet(nn.Module):
    def __init__(self, config=None):
        super(CapsNet, self).__init__()
        if config:
            self.conv_layer = ConvLayer(config.cnn_in_channels, config.cnn_out_channels, config.cnn_kernel_size)
            self.primary_capsules = PrimaryCaps(config.pc_num_capsules, config.pc_in_channels, config.pc_out_channels,
                                                config.pc_kernel_size, config.pc_num_routes)
            self.digit_capsules = DigitCaps(config.dc_num_capsules, config.dc_num_routes, config.dc_in_channels,
                                            config.dc_out_channels)
            self.decoder = Decoder(config.input_width, config.input_height, config.cnn_in_channels)
        else:
            self.conv_layer = ConvLayer()
            self.primary_capsules = PrimaryCaps()
            self.digit_capsules = DigitCaps()
            self.decoder = Decoder()

        self.mse_loss = nn.MSELoss()

    def forward(self, data):
        output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))
        reconstructions, masked = self.decoder(output, data)
        return output, reconstructions, masked

    def loss(self, data, x, target, reconstructions):
        return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)

    def margin_loss(self, x, labels, size_average=True):
        batch_size = x.size(0)

        v_c = torch.sqrt((x ** 2).sum(dim=2, keepdim=True))

        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)

        loss = labels * left + 0.5 * (1.0 - labels) * right
        loss = loss.sum(dim=1).mean()

        return loss

    def reconstruction_loss(self, data, reconstructions):
        loss = self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))
        return loss * 0.0005

##################################################################

So the source of the first NaN is the output tensor of the squashing function in the digit capsule. I used torch.nan_to_num to get rid of those NaNs, and then the ReLU in the conv layer started returning tensors that were entirely NaN.
PS: during training I saw a lot of values printed as 0.00000, and the code I actually run is based on Conv1d, since I use time series instead of images.
Thank you

I don’t think that replacing the NaNs with other values would solve the issue, as I would assume that the backward pass would still create invalid gradients (you should verify this). The proper fix would be to avoid dividing by zero (or by a very small number) in:

output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))

E.g. you could add a small eps value to the division.
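
For instance, a sketch of that change (the eps value here is arbitrary, just for illustration):

def squash(self, input_tensor, eps=1e-8):
    squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
    # eps keeps the denominator away from exact zero
    output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm) + eps)
    return output_tensor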

I applied that by adding 0.1 to avoid the division by zero. Now the program runs a bit longer before the NaNs appear, but this time they seem to be related to the sqrt operation (even though I added a small term there as well), judging from the error message:

Traceback (most recent call last):

File "C:\Users\obouldjedr\Desktop\lastcode4\cp\test_capsnet.py", line 538, in <module>
train(capsule_net, optimizer, trainloaderIMUhand, e)

File "C:\Users\obouldjedr\Desktop\lastcode4\cp\test_capsnet.py", line 223, in train
loss.backward()

File "C:\Users\obouldjedr\Anaconda3\lib\site-packages\torch\tensor.py", line 245, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File "C:\Users\obouldjedr\Anaconda3\lib\site-packages\torch\autograd\__init__.py", line 145, in backward
Variable._execution_engine.run_backward(

RuntimeError: Function 'SqrtBackward' returned nan values in its 0th output.

What do you think the solution would be? Thank you in advance for your help.

The torch.sqrt method would create an Inf gradient for a zero input and a NaN output and gradient for a negative input, so you could add an eps value there as well or make sure the input is a positive number:

x = torch.tensor([0.], requires_grad=True)
y = torch.sqrt(x)
y.backward()
print(x.grad)
> tensor([inf])
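
Putting both fixes together, one possible squash variant (a sketch; the clamp threshold eps is an arbitrary choice) that avoids both the zero division and the sqrt gradient at zero:

def squash(self, input_tensor, eps=1e-8):
    squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
    # clamping before sqrt keeps the sqrt gradient finite at (near-)zero norms
    safe_norm = torch.sqrt(squared_norm.clamp(min=eps))
    scale = squared_norm / (1. + squared_norm)
    return scale * input_tensor / safe_norm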