Unexplained SGD behaviour

Can someone please shed some light on what’s going on here? Why do we get exploding loss leading to nan?


import torch
import numpy as np
import torchvision as tv
import torch.nn.functional as F

stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

train_loader = torch.utils.data.DataLoader(
    tv.datasets.CIFAR10(root='./data', train=False,
                        transform=tv.transforms.Compose([
                            tv.transforms.Resize(64),
                            tv.transforms.ToTensor(),
                            tv.transforms.Normalize(*stats)
                        ]), download=True),
    batch_size=512, shuffle=True,
    num_workers=4, pin_memory=False)

dataiter = next(iter(train_loader))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = tv.models.AlexNet(num_classes=10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=5e-4, momentum=0.9, nesterov=True)

for epoch in range(150):
    for data, target in [dataiter] * 50:
        data, target = data.to(device), target.to(device)
        outputs = model(data)
        outputs = F.log_softmax(outputs, dim=1)
        loss = F.nll_loss(outputs, target)
        print("loss: {}".format(loss.item()))
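        probs, preds = torch.max(outputs, 1)
        # update step: clear old gradients, backpropagate, apply SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()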

When did you meet this nan? I cannot reproduce it…

Seriously, you can’t? So the MWE works fine for you, without any of the exploding-loss issues? WTF? The nan came from running the example above. I think it’s super sensitive to the learning rate; if you play with it a bit you’ll probably hit the hot spot sooner rather than later. I’m saying this because I first hit this nan when using a step scheduler, and everything randomly started to explode. I kept looking at the same 10 lines of training code over and over, questioning whether optimizer.zero_grad() should go before the loss or after, and all sorts of crazy stuff.

@kirk86 I guess the learning rate is slightly too high. Do you still get NaNs with a lr of 0.01? You could test whether gradually warming up the learning rate is beneficial.

@kirk86 I played a bit with your script and can confirm that it is unexpectedly hard to get quick convergence to near-zero loss on the single test batch. It works immediately for me with a significantly smaller learning rate like 1e-3 and small batch sizes like 1, 4 or 32 (no time to test others).

@andreaskoepf Thanks for the reply. Yup, with a lr of 0.01 it works fine, but what throws me off is this insane sensitivity of the optimizer. If I switch to Adam I don’t get nan, but the loss stagnates and stays roughly the same.

What do we mean these days by warming up the learning rate? There are so many heuristics that it’s hard to keep track of them all.

@andreaskoepf Appreciate it! But it’s strange, because with smaller batch sizes and learning rates we lose all the benefits of GPUs and fast training, so to me it seems we’re back at square one.

On GPU it is normal to tune the lr. A learning rate that is too large can make the loss explode.

I thought of gradually increasing the learning rate during a ramp-up period at the beginning of the training, e.g. with a learning rate scheduler like https://github.com/ildoonet/pytorch-gradual-warmup-lr
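A linear ramp-up like this can also be sketched with PyTorch's built-in LambdaLR scheduler, without the linked package. This is only a minimal illustration, not that repository's implementation: the helper name linear_warmup, the 10-step warmup length and the tiny stand-in model are all assumptions made for the example.

```python
import torch

def linear_warmup(warmup_steps):
    """Hypothetical helper: returns a multiplier that ramps linearly
    from 1/warmup_steps up to 1.0 and then stays at 1.0."""
    def factor(step):
        return min(1.0, (step + 1) / warmup_steps)
    return factor

# Stand-in model (assumption); the thread uses tv.models.AlexNet.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, nesterov=True)
# LambdaLR multiplies the base lr (0.1) by factor(step).
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=linear_warmup(10))

lrs = []
for step in range(15):
    lrs.append(optimizer.param_groups[0]["lr"])
    # forward pass, optimizer.zero_grad() and loss.backward() would go here
    optimizer.step()
    scheduler.step()  # advance the warmup once per iteration

print(lrs)
```

Here the scheduler is stepped once per iteration rather than once per epoch, so the learning rate reaches its target value of 0.1 after the first 10 updates and stays there.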

In general I would recommend experimenting with different optimizers, weight initializations, activation functions and learning rates. Collect the loss logs of the experiments and plot them together to see what works best. To analyze things further, you could print out the max-abs values (or a norm) of the gradients.

Thanks Andreas,

If you don’t mind me asking, does that mean taking the max abs value (or the norm) over all variables that have requires_grad=True, or just over specific variables/layers with grads?
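One possible reading of the suggestion, sketched under assumptions: loop over every parameter that actually received a gradient, record the per-parameter max-abs value, and also accumulate a single global L2 norm. The function name grad_stats and the small stand-in model are made up for the example; in the script above you would call it on the AlexNet model right after loss.backward().

```python
import torch
import torch.nn.functional as F

def grad_stats(model):
    """Hypothetical helper: per-parameter max-abs gradient values
    plus the global L2 norm over all gradients."""
    max_abs = {}
    sq_sum = 0.0
    for name, p in model.named_parameters():
        if p.grad is None:  # frozen or unused parameters have no grad
            continue
        g = p.grad.detach()
        max_abs[name] = g.abs().max().item()
        sq_sum += g.pow(2).sum().item()
    return max_abs, sq_sum ** 0.5

# Stand-in model and random data (assumptions for the sketch).
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
x = torch.randn(16, 4)
y = torch.randint(0, 2, (16,))

loss = F.cross_entropy(model(x), y)
loss.backward()

per_param, total_norm = grad_stats(model)
for name, value in per_param.items():
    print("{}: max |grad| = {:.4e}".format(name, value))
print("global grad norm: {:.4e}".format(total_norm))
```

Logging these numbers every iteration shows which layer's gradients blow up first, and whether the growth starts before the loss itself turns into nan.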