Can anyone help me understand why the following custom layer I implemented is not numerically stable? I keep getting NaN when using it.

import math
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

torch.set_printoptions(precision=10)

class LargeMarginSoftmaxLinear(nn.Module):
    def __init__(self, in_features, out_features, m, _lambda, use_cuda, device):
        super(LargeMarginSoftmaxLinear, self).__init__()
        self.w = Parameter(torch.zeros(out_features, in_features).to(device))
        self.reset_parameters()
        self.m = int(m)
        self._lambda = float(_lambda)
        # Precompute the binomial coefficients C(m, n) and the cos(k*pi/m)
        # thresholds used in evaluate_cos_m_theta() and determine_k().
        self.m_choose_n_map = torch.zeros(self.m + 1)
        self.k_map = torch.zeros(self.m + 1)
        for i in range(self.m + 1):
            self.m_choose_n_map[i] = math.factorial(self.m) / math.factorial(i) / math.factorial(self.m - i)
            self.k_map[i] = math.cos(i * math.pi / self.m)

        # any better way than this for device agnostic code?
        self.use_cuda = use_cuda
        self.device = device
        if self.use_cuda:
            self.m_choose_n_map = self.m_choose_n_map.to(self.device)
            self.k_map = self.k_map.to(self.device)

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.w.size(1))
        self.w.data.uniform_(-stdv, stdv)

    def determine_k(self, cos_theta):
        # k = index of the interval [k*pi/m, (k+1)*pi/m] containing theta,
        # i.e. how many thresholds cos(j*pi/m), j = 1..m-1, are >= cos_theta
        k = torch.zeros_like(cos_theta)
        for i in range(self.m - 1):
            k += (self.k_map[i + 1] >= cos_theta).type(cos_theta.dtype)
        return k

    def evaluate_cos_m_theta(self, cos_theta):
        # expansion: cos(m*theta) = sum_n (-1)^n * C(m, 2n) * cos(theta)^(m-2n) * sin(theta)^(2n)
        sin_square_theta = 1 - cos_theta.pow(2)
        n = torch.arange(0, self.m // 2 + 1, dtype=cos_theta.dtype).view(-1, 1)
        if self.use_cuda:
            sin_square_theta = sin_square_theta.to(self.device)
            n = n.to(self.device)

        cos_m_theta = pow(-1, n) * self.m_choose_n_map[2*n.long()] * cos_theta.pow(self.m - 2*n) * sin_square_theta.pow(n)
        cos_m_theta = torch.sum(cos_m_theta, 0)
        return cos_m_theta

    def forward(self, x, y=None):
        x_dot_wT = x.mm(self.w.transpose(0, 1))
        # start from the plain logits; clone (rather than torch.tensor) so the copy stays in the graph
        f_y_i = x_dot_wT.clone()
        if y is not None:
            batch_size = y.size(0)
            w_norm = self.w.norm(p=2, dim=1)
            x_norm = x.norm(p=2, dim=1)
            # cosine between each sample and the weight vector of its target class
            y_i = x_dot_wT.gather(1, y.view(-1, 1)).squeeze(dim=1)
            cos_theta = y_i / (x_norm * w_norm.index_select(0, y))
            cos_m_theta = self.evaluate_cos_m_theta(cos_theta)
            k = self.determine_k(cos_theta)
            idxs = torch.arange(0, batch_size, dtype=torch.long)
            if self.use_cuda:
                idxs = idxs.to(self.device)
            # replace the target-class logits with the large-margin version, annealed by lambda
            f_y_i[idxs, y] = ((self._lambda * y_i) + ((pow(-1, k) * cos_m_theta - 2*k) * x_norm * w_norm.index_select(0, y))) / (1 + self._lambda)
        return f_y_i

I couldn’t figure out the problem.

I can train longer before hitting NaN when I use double tensors, which suggests it is a numerical problem.

Before hitting NaN, the loss does decrease in the double-precision case.
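
For reference, the double-precision run is just the script below with the model and data cast to double, roughly like this:

# rough sketch of the double-precision variant (Net, device, train_data, test_data as in the script below)
model = Net().double().to(device)
train_data = train_data.double()
test_data = test_data.double()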

The full training script is as follows (it requires the MNIST data):

from __future__ import print_function
import time
import math
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms
from experiment import LargeMarginSoftmaxLinear

torch.set_printoptions(precision=20)

tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
             (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
             (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
             (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
             (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229)]
for i in range(len(tableau20)):
    r, g, b = tableau20[i]
    tableau20[i] = (r / 255., g / 255., b / 255.)

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                    help='input batch size for training (default: 256)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=30, metavar='N',
                    help='number of epochs to train (default: 30)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight_decay', type=float, default=0.0005, metavar='W',
                    help='SGD weight decay (default: 0.0005)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--vis_path', type=str, default="visualizations/color6", metavar='S',
                    help='path to save your visualization figures')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

torch.backends.cudnn.benchmark = True if use_cuda else False


train_data, train_labels = torch.load("../data/processed/training.pt")
randperm = torch.randperm(train_data.size(0))
train_data = train_data[randperm]
train_labels = train_labels[randperm]
train_data = train_data.float()
train_data.div_(255)
train_data.unsqueeze_(1)
train_data.sub_(0.1307).div_(0.3081)
train_data = train_data.to(device)
train_labels = train_labels.to(device)

test_data, test_labels = torch.load("../data/processed/test.pt")
test_data = test_data.float()
test_data.div_(255)
test_data.unsqueeze_(1)
test_data.sub_(0.1307).div_(0.3081)
test_data = test_data.to(device)
test_labels = test_labels.to(device)

def weights_init(m):
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.Linear):
        nn.init.constant_(m.bias, 0)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(1024, 256)
        self.fc2 = nn.Linear(256, 2)
        self.fc3 = LargeMarginSoftmaxLinear(2, 10, 3, 0, use_cuda, device)

    def forward(self, x, y=None):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, stride=2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, stride=2)
        x = x.view(-1, 1024)
        x = F.relu(self.fc1(x))
        features = self.fc2(x)
        x = self.fc3(features, y)
        return x, features

model = Net().to(device)
model.apply(weights_init)

criterion = nn.CrossEntropyLoss().to(device)

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

batch_size = args.batch_size
test_batch_size = args.test_batch_size

def train(epoch):
    model.train()
    for batch_idx in range(0, train_data.size(0), batch_size):
        batch_data = train_data[batch_idx: batch_idx+batch_size]
        batch_labels = train_labels[batch_idx: batch_idx+batch_size]
        optimizer.zero_grad()
        x, features = model(batch_data, batch_labels)
        loss = criterion(x, batch_labels)
        loss.backward()
        optimizer.step()
        if (batch_idx / batch_size) % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.2f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx, train_data.size(0),
                100. * batch_idx / train_data.size(0), loss.item()))

def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx in range(0, test_data.size(0), test_batch_size):
            batch_data = test_data[batch_idx: batch_idx+test_batch_size]
            batch_labels = test_labels[batch_idx: batch_idx+test_batch_size]
            x, features = model(batch_data)
            batch_scores = F.log_softmax(x, dim=1)
            test_loss += F.nll_loss(batch_scores, batch_labels, reduction='sum').item() # sum up batch loss
            pred = batch_scores.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(batch_labels.view_as(pred)).sum().item()

    test_loss /= test_data.size(0)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, test_data.size(0),
        100. * correct / test_data.size(0)))

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
for epoch in range(1, args.epochs + 1):
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    train(epoch)
    if use_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()
    print('Total time taken: {:.4f}s'.format(end - start))
    test()
    scheduler.step()  # step the LR scheduler after the epoch (required since PyTorch 1.1)

From my experience, NaN usually happens when you divide by 0 (or by a small enough number). I see in your code that you divide y_i by a norm; it could be there.
The other approach is to use a debugger (or many print statements) to find a case where you get NaN and determine during which operation it first appears.
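
For example, something along these lines (a rough sketch; check is just a throwaway helper, not part of your layer):

import torch

# rough sketch: assert on NaN/Inf in the intermediate tensors of forward()
def check(name, t):
    if torch.isnan(t).any() or torch.isinf(t).any():
        raise RuntimeError('NaN/Inf detected in {}'.format(name))

# e.g. inside LargeMarginSoftmaxLinear.forward:
#   check('cos_theta', cos_theta)
#   check('cos_m_theta', cos_m_theta)
#   check('f_y_i', f_y_i)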

I have checked, and in the last iteration before the NaN appears there is no 0 in the denominator of the line you mentioned.

In fact, there are no NaN values anywhere in the forward pass before the NaN occurs, so I suspect it happens in the backward pass when the gradient is computed. I'm not sure how to debug this!
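
What I ended up trying (a rough sketch, assuming a PyTorch version that has anomaly detection; it slows training, so debugging runs only):

import torch

# make backward raise at the op that first produces a NaN
torch.autograd.set_detect_anomaly(True)

# and/or hook the gradients of suspicious intermediates inside forward(), e.g.
#   cos_theta.register_hook(lambda g: print('cos_theta grad has NaN:', torch.isnan(g).any().item()))
#   x_norm.register_hook(lambda g: print('x_norm grad has NaN:', torch.isnan(g).any().item()))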

OK, found it. The problem is with the norm: apparently torch's norm function is not numerically stable in the backward pass for very small values (→ 0). The gradient of ||x|| is x / ||x||, so backward divides by the norm itself and blows up (or becomes NaN) as the norm approaches 0.
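
To illustrate (a minimal sketch; safe_norm is just a hypothetical helper for the workaround, not part of the layer above):

import torch

# gradient of ||x|| w.r.t. x is x / ||x||, which is 0/0 = NaN at x = 0
x = torch.zeros(3, requires_grad=True)
x.norm(p=2).backward()
print(x.grad)  # tensor([nan, nan, nan])

# workaround sketch: keep the sum of squares away from zero before the sqrt,
# e.g. use something like this instead of x.norm(p=2, dim=1) in forward()
def safe_norm(t, dim=1, eps=1e-12):
    return (t.pow(2).sum(dim=dim) + eps).sqrt()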