Basic VGG16 with autocast mixed precision creates NaN gradients

Hello,

The full code and a link to Google Colab are below.

I want to use a basic VGG16 as a feature extractor. I use VGG16 from torchvision.models and remove the FC layers and the average pooling layer (as well as the last max-pooling layer).
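
A quick way to sanity-check the layer slicing is to work out the output shape by hand. The sketch below assumes torchvision's VGG16 layout (3x3 convolutions with padding 1, which keep height and width unchanged, and `MaxPool2d(kernel_size=2, stride=2)` layers, which halve them) and that four of the five max-pools remain once the last one is dropped. `vgg16_feature_shape` is a hypothetical helper for illustration, not part of torchvision.

```python
def vgg16_feature_shape(height, width, num_pools=4, channels=512):
    """Return (C, H, W) after `num_pools` 2x2/stride-2 max-pools.

    Assumes 3x3 convs with padding 1 (shape-preserving) and that each
    MaxPool2d(kernel_size=2, stride=2) floors H and W by a factor of 2.
    """
    for _ in range(num_pools):
        height //= 2
        width //= 2
    return channels, height, width

# Input from the post: 1x3x512x1760, with the final max-pool removed.
print(vgg16_feature_shape(512, 1760))
```

The result, (512, 32, 110), matches the `torch.zeros(1,512,32,110)` target used in the code.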

Then I create a dummy input and target and use MSE loss. When I enable cuda.amp.autocast, some of the gradients are immediately either infinite or NaN. With cuda.amp.autocast disabled, everything works fine and no NaNs appear.
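
Some inf/NaN gradients right after enabling AMP are expected on their own: float16 can only represent magnitudes up to 65504, while `GradScaler` starts with a loss scale of 2**16 = 65536 by default (its `init_scale`). A gradient that flows through float16 ops therefore overflows once its unscaled magnitude is roughly 1 or more. The sketch below uses Python's half-precision `struct` format (`'e'`) to show the limit without PyTorch; the gradient magnitudes are made-up values, and which gradients are actually held in float16 depends on which ops autocast runs in half precision, so treat this as an illustration rather than a diagnosis.

```python
import struct

FP16_MAX = 65504.0       # largest finite IEEE 754 half-precision value
INIT_SCALE = 2.0 ** 16   # GradScaler's default init_scale

def fits_in_fp16(value):
    """Return True if `value` can be packed as a finite half-precision float."""
    try:
        struct.pack("<e", value)
        return True
    except OverflowError:
        return False

# Hypothetical unscaled gradient magnitudes, multiplied by the initial loss scale.
for grad in (1e-4, 0.5, 2.0):
    scaled = grad * INIT_SCALE
    print(f"grad={grad:g} scaled={scaled:g} fits_in_fp16={fits_in_fp16(scaled)}")
```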

I noticed the same problem with my ResNet50, but there the NaNs in the gradients disappear after 5 iterations. I assume this is because of

    scaler.step(optimizer)
    scaler.update()

But with the VGG16 model the NaNs don't seem to disappear, and after a while the loss becomes NaN.
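
The ResNet50 recovery is consistent with GradScaler's dynamic loss scaling: when `scaler.step(optimizer)` finds inf/NaN gradients it skips `optimizer.step()`, and `scaler.update()` multiplies the scale by `backoff_factor` (0.5 by default); after `growth_interval` (2000 by default) consecutive clean steps it multiplies the scale by `growth_factor` (2.0 by default). The pure-Python sketch below mimics only that bookkeeping. `ToyLossScaler` is hypothetical, it merges step and update into one call, and its overflow rule ("largest scaled gradient exceeds the float16 maximum") and the gradient magnitude of 3.0 are simplifying assumptions, not PyTorch's actual check.

```python
FP16_MAX = 65504.0

class ToyLossScaler:
    """Simplified dynamic loss scaling; defaults mirror GradScaler's documented ones."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def step(self, max_grad):
        """Return True if the optimizer step runs, False if it is skipped."""
        if max_grad * self.scale > FP16_MAX:
            # Overflow: skip this step and shrink the scale.
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # A long run of clean steps: try a larger scale again.
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True

scaler = ToyLossScaler()
for i in range(5):
    ran = scaler.step(max_grad=3.0)
    print(f"iteration {i}: step {'taken' if ran else 'skipped'}, scale now {scaler.scale:g}")
```

With these numbers the first two steps are skipped and the scale settles at 16384, after which steps go through. If the gradients keep growing because training itself is diverging, the scaler can only keep backing off, which matches the VGG16 behavior described here.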

Here is the Google Colab link. You can set use_mixed_precision to either True or False to compare the results. Don't forget to set your runtime type to GPU.

Is this a problem with the VGG16 from torchvision.models or am I doing something wrong?

Full Code:

```python
import torch
import torchvision.models as models
import torch.nn as nn

class VggExtractor(torch.nn.Module):
  def __init__(self):
    super(VggExtractor, self).__init__()
    self.vgg16 = models.vgg16(pretrained=False)
    self.feature_extractor_layers = list(self.vgg16.children())[:-2]  # remove FC layers and AVG pool
    self.stem = torch.nn.Sequential(*(list(self.feature_extractor_layers[0].children())[0:5]))
    self.layer1 = torch.nn.Sequential(*(list(self.feature_extractor_layers[0].children())[5:10]))
    self.layer2 = torch.nn.Sequential(*(list(self.feature_extractor_layers[0].children())[10:17]))
    self.layer3 = torch.nn.Sequential(*(list(self.feature_extractor_layers[0].children())[17:24]))
    self.layer4 = torch.nn.Sequential(*(list(self.feature_extractor_layers[0].children())[24:-1]))  # remove last maxpool

  def forward(self, x):
    x = self.stem(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    return x

def feedforward_and_loss(net, images, target):
  output = net(images)
  loss_function = nn.MSELoss(reduction="sum")
  loss = loss_function(output, target)
  return loss

def print_nan_found(net, i):
  nan_found = False
  for name, parameter in net.named_parameters():
    if nan_found:
      break
    if parameter.grad is not None:
      nan_found = torch.isnan(parameter.grad).any()
  # bool() also works if no parameter had a gradient (nan_found is then still a plain bool)
  print("iteration", i, " nan found ", bool(nan_found))

net = VggExtractor()
net = net.cuda()

use_mixed_precision = True  # Set this to False in order to disable mixed precision
scaler = torch.cuda.amp.GradScaler()
# autocast only takes effect inside a `with` block (as in the loop below);
# calling torch.cuda.amp.autocast(...) on its own line does nothing.
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.005)

for i in range(1000):
  optimizer.zero_grad()

  images = torch.ones(1, 3, 512, 1760).cuda()
  target = torch.zeros(1, 512, 32, 110).cuda()

  if use_mixed_precision:
    with torch.cuda.amp.autocast():
      loss = feedforward_and_loss(net, images, target)
    scaled_loss = scaler.scale(loss)
    scaled_loss.backward()
    # scaler.step() unscales the gradients and skips optimizer.step() if they contain inf/NaN
    scaler.step(optimizer)
    scaler.update()
    print_nan_found(net, i)
  else:
    loss = feedforward_and_loss(net, images, target)
    loss.backward()
    optimizer.step()
    print_nan_found(net, i)
```


Greetings Rupert


Thanks for the code.
If I set use_mixed_precision=False, the code also outputs NaNs, so it seems your training blows up:

```
iteration 0  nan found  False
Loss is  10132.611328125
iteration 1  nan found  True
Loss is  inf
iteration 2  nan found  True
Loss is  nan
iteration 3  nan found  True
Loss is  nan
iteration 4  nan found  True
Loss is  nan
iteration 5  nan found  True
Loss is  nan
iteration 6  nan found  True
Loss is  nan
iteration 7  nan found  True
Loss is  nan
iteration 8  nan found  True
Loss is  nan
```

I guess the learning rate is just too high (for both AMP and vanilla FP32 training).
If you use reduction='mean' or lower the learning rate, the model seems to work fine in both modes (for at least 60 epochs, after which I stopped it).
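
Why `reduction='sum'` acts like a much larger learning rate: the gradient of a summed MSE with respect to each output element is 2·(y − t), while the mean version divides that by the number of elements N. For the 1x512x32x110 target here, N = 1,802,240, so plain SGD with momentum takes steps about N times larger with the summed loss (the `weight_decay` term is the one part that is not scaled by N). The pure-Python sketch below checks the ratio on a small made-up vector and then runs gradient descent on a hypothetical one-parameter toy problem, where the same learning rate is stable with the mean but diverges with the sum; the data, `n`, and learning rate are illustrative assumptions.

```python
def mse_grads(outputs, targets, reduction):
    """Gradient of MSE w.r.t. each output element for 'sum' or 'mean' reduction."""
    n = len(outputs)
    factor = 1.0 if reduction == "sum" else 1.0 / n
    return [2.0 * (y - t) * factor for y, t in zip(outputs, targets)]

outputs = [0.5, -1.0, 2.0, 0.25]
targets = [0.0, 0.0, 0.0, 0.0]
g_sum = mse_grads(outputs, targets, "sum")
g_mean = mse_grads(outputs, targets, "mean")
print([s / m for s, m in zip(g_sum, g_mean)])  # every ratio equals len(outputs)

def descend(reduction, lr, steps=20, n=1000):
    """Toy model: n outputs all equal to one weight w, all targets zero.

    Summed loss n*w**2 has gradient 2*n*w; mean loss w**2 has gradient 2*w.
    """
    w = 1.0
    for _ in range(steps):
        grad = 2.0 * w * (n if reduction == "sum" else 1)
        w -= lr * grad
    return w

print("mean:", descend("mean", lr=0.01))  # shrinks toward 0 (factor 0.98 per step)
print("sum: ", descend("sum", lr=0.01))   # |w| grows (factor -19 per step)
```

Gradient descent on this quadratic is stable only while lr times the curvature stays below 2, so multiplying the curvature by N through the summed loss pushes a learning rate that is fine for the mean far past that bound.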
