Model returning NaN as output

I’m trying to implement a variant of a capsule network where the matrix multiplication is replaced by element-wise multiplication with a vector. During training (usually right after the first backpropagation) the outputs become NaN. I tried gradient clipping, but it didn’t work. I’m working with the MNIST dataset and I’m normalizing it before training.

## Training data loading and normalizing
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.nn.utils import clip_grad_norm_

USE_CUDA = torch.cuda.is_available()

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

trainset = torchvision.datasets.MNIST('/mnist/', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, num_workers=2, batch_size=10)

## Model Architecture
class ConvLayer(nn.Module):
    def __init__(self, in_channels=1, out_channels=256, kernel_size=9):
        super(ConvLayer, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=1
                              )

    def forward(self, x):
        return F.relu(self.conv(x))


class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9, num_routes=32 * 6 * 6):
        super(PrimaryCaps, self).__init__()
        self.num_routes = num_routes
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0)
            for _ in range(num_capsules)])

    def forward(self, x):
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=1)

        u = u.view(x.size(0), self.num_routes, -1)
        return self.squash(u)

    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor


class DigitCaps(nn.Module):
    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules

        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, in_channels, 1))

    def forward(self, x):
        batch_size = x.size(0)
        x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)

        W = torch.cat([self.W] * batch_size, dim=0)
        u_hat = W * x

        b_ij = torch.zeros(1, self.num_routes, self.num_capsules, 1)
        if USE_CUDA:
            b_ij = b_ij.cuda()

        num_iterations = 3
        for iteration in range(num_iterations):
            c_ij = F.softmax(b_ij, dim=1)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = self.squash(s_j)

            if iteration < num_iterations - 1:
                a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

        return v_j.squeeze(1)

    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor


class CapsNet(nn.Module):
    def __init__(self, config=None):
        super(CapsNet, self).__init__()
        if config:
            self.conv_layer = ConvLayer(config.cnn_in_channels, config.cnn_out_channels, config.cnn_kernel_size)
            self.primary_capsules = PrimaryCaps(config.pc_num_capsules, config.pc_in_channels, config.pc_out_channels,
                                                config.pc_kernel_size, config.pc_num_routes)
            self.digit_capsules = DigitCaps(config.dc_num_capsules, config.dc_num_routes, config.dc_in_channels,
                                            config.dc_out_channels)
        else:
            self.conv_layer = ConvLayer()
            self.primary_capsules = PrimaryCaps()
            self.digit_capsules = DigitCaps()

        self.mse_loss = nn.MSELoss()

    def forward(self, data):
        output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))
        return output

class Config_1:
    def __init__(self, in_channels):
            # CNN (cnn)
            self.cnn_in_channels = in_channels
            self.cnn_out_channels = 256
            self.cnn_kernel_size = 3

            # Primary Capsule (pc)
            self.pc_num_capsules = 8
            self.pc_in_channels = 256
            self.pc_out_channels = 32
            self.pc_kernel_size = 3
            self.pc_num_routes = 32 * 6 * 6

            # Digit Capsule (dc)
            self.dc_num_capsules = 10
            self.dc_num_routes = 32 * 6 * 6
            self.dc_in_channels = 18
            self.dc_out_channels = 1

class Config_2:
    def __init__(self, in_channels):
            # CNN (cnn)
            self.cnn_in_channels = in_channels
            self.cnn_out_channels = 256
            self.cnn_kernel_size = 3

            # Primary Capsule (pc)
            self.pc_num_capsules = 8
            self.pc_in_channels = 256
            self.pc_out_channels = 32
            self.pc_kernel_size = 3
            self.pc_num_routes = 32 * 6 * 6

            # Digit Capsule (dc)
            self.dc_num_capsules = 10
            self.dc_num_routes = 32 * 6 * 6
            self.dc_in_channels = 8
            self.dc_out_channels = 1

class Config_3:
    def __init__(self, in_channels):
            # CNN (cnn)
            self.cnn_in_channels = in_channels
            self.cnn_out_channels = 256
            self.cnn_kernel_size = 3

            # Primary Capsule (pc)
            self.pc_num_capsules = 8
            self.pc_in_channels = 256
            self.pc_out_channels = 32
            self.pc_kernel_size = 3
            self.pc_num_routes = 32 * 6 * 6

            # Digit Capsule (dc)
            self.dc_num_capsules = 10
            self.dc_num_routes = 32 * 6 * 6
            self.dc_in_channels = 2
            self.dc_out_channels = 1

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(1,32,kernel_size = (3,3))
        self.conv2 = nn.Conv2d(32,48,kernel_size = (3,3))
        self.conv3 = nn.Conv2d(48,64,kernel_size = (3,3))
        self.conv4 = nn.Conv2d(64,80,kernel_size = (3,3))
        self.conv5 = nn.Conv2d(80,96,kernel_size = (3,3))
        self.conv6 = nn.Conv2d(96,112,kernel_size = (3,3))
        self.conv7 = nn.Conv2d(112,128,kernel_size = (3,3))
        self.conv8 = nn.Conv2d(128,144,kernel_size = (3,3))
        self.conv9 = nn.Conv2d(144,160,kernel_size = (3,3))
        self.caps_a = CapsNet(Config_1(64))
        self.caps_b = CapsNet(Config_2(112))
        self.caps_c = CapsNet(Config_3(160))
        self.merge_weight1= nn.Parameter(torch.randn(1))
        self.merge_weight2= nn.Parameter(torch.randn(1))
        self.merge_weight3= nn.Parameter(torch.randn(1))
        self.relu = nn.ReLU()
        
    def forward(self,x):
        branch1 = self.conv1(x)
        branch1 = self.relu(branch1)
        branch1 = self.conv2(branch1)
        branch1 = self.relu(branch1)
        branch1 = self.conv3(branch1)
        branch1 = self.relu(branch1)
        branch2 = self.conv4(branch1)
        branch2 = self.relu(branch2)
        branch2 = self.conv5(branch2)
        branch2 = self.relu(branch2)
        branch2 = self.conv6(branch2)
        branch2 = self.relu(branch2)
        branch3 = self.conv7(branch2)
        branch3 = self.relu(branch3)
        branch3 = self.conv8(branch3)
        branch3 = self.relu(branch3)
        branch3 = self.conv9(branch3)
        branch3 = self.relu(branch3)

        branch1_out = self.caps_a(branch1)
        branch2_out = self.caps_b(branch2)
        branch3_out = self.caps_c(branch3)

     
        branch1_out = torch.sqrt((branch1_out ** 2).sum(dim=2, keepdim=True)).view(-1,10)
        branch2_out = torch.sqrt((branch2_out ** 2).sum(dim=2, keepdim=True)).view(-1,10)
        branch3_out = torch.sqrt((branch3_out ** 2).sum(dim=2, keepdim=True)).view(-1,10)
        
        out1 = self.merge_weight1 * branch1_out
        out2 = self.merge_weight2 * branch2_out
        out3 = self.merge_weight3 * branch3_out
        stack = torch.stack([out1, out2, out3], dim=0)
        summed = torch.sum(stack, dim=0)
        return summed

net = Net()
if USE_CUDA:
  net = net.cuda()
criterion = nn.CrossEntropyLoss()
lr = 0.01
optimizer = torch.optim.Adam(net.parameters(),lr = lr)

## Training of the network 

for epoch in range(5):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        if USE_CUDA:
          inputs = inputs.cuda()
          labels = labels.cuda()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        if torch.isnan(loss):
          print("nan Loss")
          running_loss = 0.0
          continue
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(net.parameters(), 0.1)
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:  
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0
print('Finished Training')

Here’s a sample of the output:

[1,   100] loss: 2.301
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
nan Loss
............

Any idea what is causing this, or any suggestions on how to fix it?

Could you post your model definition, so that we can have a look at it and help debug?

I am not sure if this is the reason, but your learning rate of 0.01 seems a bit high. Whenever I have used a learning rate of 0.01 with the Adam optimizer I normally got NaNs very early in training. That was with a different architecture and the TensorFlow framework, but try reducing your learning rate and see if it helps.
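
For example (just a sketch, the exact value will need tuning):

# try a smaller learning rate for Adam than 0.01
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)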

I have added the model definition in the post.

I’ve tried lowering the learning rate; the result is still the same.

Sounds stupid, but have you tried printing the branch outputs and squared_norm before and after you square and sum them?
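
For example, a quick check with a (made-up) helper like this would show which intermediate tensor first contains NaNs or Infs:

def check_tensor(name, t):
    # print min/max and whether the tensor contains any NaNs or Infs
    print(name, t.min().item(), t.max().item(),
          torch.isnan(t).any().item(), torch.isinf(t).any().item())

torch.autograd.set_detect_anomaly(True) can also help locate the operation that first produces a NaN during the backward pass.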

I’ve tried printing them for a few batches; the values printed were all finite, not NaNs.

The values of squared_norm in PrimaryCaps explode and create the NaNs.
In the last iteration before the NaNs are raised PrimaryCaps creates tensors with these statistics:

print(squared_norm.min(), squared_norm.max())
> tensor(1.4527e+20, device='cuda:0', grad_fn=<MinBackward1>) tensor(5.3600e+26, device='cuda:0', grad_fn=<MaxBackward1>)
print(input_tensor.min(), input_tensor.max())
> tensor(-1.6448e+13, device='cuda:0', grad_fn=<MinBackward1>) tensor(1.8647e+13, device='cuda:0', grad_fn=<MaxBackward1>)
print(output_tensor.min(), output_tensor.max())
> tensor(nan, device='cuda:0', grad_fn=<MinBackward1>) tensor(nan, device='cuda:0', grad_fn=<MaxBackward1>)

This can be explained via:

squared_norm = torch.tensor(1e26)
input_tensor = torch.tensor(1e13)
output_tensor = squared_norm * input_tensor / ((1+squared_norm) * torch.sqrt(squared_norm)) # divisor is Inf
print(output_tensor)
> tensor(nan)

So it seems your overall training might not be stable.

Thanks for the input. I rewrote the squash function as given below:

def squash(self, input_tensor):
    squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
    denom = (1. + squared_norm) * torch.sqrt(squared_norm)
    if torch.isinf(denom).sum().item() > 0:
        output_tensor = input_tensor / torch.sqrt(squared_norm)
    else:
        output_tensor = squared_norm * input_tensor / denom
    return output_tensor

I think this solves it.
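
For reference, a common alternative (not something I have verified on this exact model) is to factor the expression so the two very large numbers are never multiplied together, with a small epsilon guarding the zero-norm case:

def squash(self, input_tensor, eps=1e-8):
    squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
    scale = squared_norm / (1. + squared_norm)   # the ||s||^2 / (1 + ||s||^2) factor, bounded
    return scale * input_tensor / (torch.sqrt(squared_norm) + eps)

Here the ratio squared_norm / (1. + squared_norm) stays bounded even when the norm explodes, so the divisor never overflows to Inf as it did above.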

I ran into the same problem. I fixed it by removing NaNs from the training data and scaling the features. Both StandardScaler() and MinMaxScaler() from sklearn worked (cf. https://www.analyticsvidhya.com/blog/2020/04/feature-scaling-machine-learning-normalization-standardization/).
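
A minimal sketch of what I mean, assuming the raw features sit in a 2-D NumPy array X before being converted to tensors:

import numpy as np
from sklearn.preprocessing import StandardScaler

X = np.nan_to_num(X)                    # replace NaNs/Infs in the raw data
X = StandardScaler().fit_transform(X)   # zero mean, unit variance per feature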

I also had the same problem, but in my case I had forgotten to call optimizer.zero_grad(). A silly mistake, but since I made it, the same could be true for you; just wanted to share.
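
Just as a generic reminder, the usual ordering inside the training loop is:

optimizer.zero_grad()   # clear gradients accumulated from the previous step
loss.backward()         # compute fresh gradients
optimizer.step()        # update the parameters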