Output of ResNet-18 is NaN with AMP

Hello,
I want to use AMP on a ResNet-18 that was trained on CIFAR-10 without AMP (plain float32).
However, when I wrap the forward pass of the model in a torch.cuda.amp.autocast() block, the network's output becomes NaN. When I disable AMP with torch.cuda.amp.autocast(enabled=False), I get the expected output values.
Below is a minimal example that reproduces the problem.

Does anyone have an idea why the output is always NaN?

import torch

from models.cifar10_models.resnet import ResNet18

resnet = ResNet18(num_classes=10)
resnet.load_state_dict(torch.load('./pretrained_models/resnet18_cifar10.pth'))

inputs = torch.randn((124, 3, 32, 32), requires_grad=True).cuda()

resnet.cuda()
resnet.eval()
for i in range(500):
    with torch.cuda.amp.autocast():
        outputs = resnet(inputs)

print(outputs)
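
A quick way to confirm the observation above (NaN under autocast, finite values without it) is to push the same batch through the model twice, once with autocast disabled and once enabled, and compare the results. This is a minimal sketch: compare_fp32_vs_amp is a hypothetical helper, not part of PyTorch. Since torch.cuda.amp.autocast only affects CUDA ops, it enables autocast only when the inputs are on a CUDA device and otherwise just repeats the float32 pass.

```python
import torch
import torch.nn as nn


@torch.no_grad()  # inference only, no autograd graph needed
def compare_fp32_vs_amp(model: nn.Module, inputs: torch.Tensor) -> dict:
    """Hypothetical helper: run `inputs` through `model` in float32 and under autocast."""
    model.eval()
    # Reference pass in plain float32 (autocast explicitly disabled).
    with torch.cuda.amp.autocast(enabled=False):
        ref = model(inputs).float()
    # torch.cuda.amp.autocast only affects CUDA ops, so enable it only for CUDA inputs.
    with torch.cuda.amp.autocast(enabled=inputs.is_cuda):
        amp_out = model(inputs).float()
    return {
        "fp32_finite": bool(torch.isfinite(ref).all()),
        "amp_finite": bool(torch.isfinite(amp_out).all()),
        # Becomes nan if the autocast output contains NaN values.
        "max_abs_diff": (ref - amp_out).abs().max().item(),
    }
```

With the resnet and inputs from the example above, compare_fp32_vs_amp(resnet, inputs) should report amp_finite as False while fp32_finite stays True if the problem reproduces.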

Edit:
I have noticed that the problem is difficult for others to reproduce when I load a local model. Therefore, here is another example using torchvision's ResNet-18 pretrained on ImageNet. Running the code below also produces NaN outputs:

import torch
from torchvision.models.resnet import resnet18

resnet = resnet18(pretrained=True)
resnet.fc = torch.nn.Linear(512, 10)

inputs = torch.randn((124, 3, 32, 32), requires_grad=False).cuda()

resnet.cuda()
resnet.eval()
with torch.cuda.amp.autocast(enabled=True):
    outputs = resnet(inputs)

print(outputs)
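
To narrow down where the NaN values first appear, forward hooks can flag the first leaf module (in execution order) whose output is not finite. This is a debugging sketch, not an official PyTorch tool, and find_first_nonfinite is a hypothetical name. Because it only hooks leaf modules, a NaN produced by an operation that is not a module (such as the residual addition inside a ResNet block) is reported at the next module that consumes it.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def find_first_nonfinite(model: nn.Module, inputs: torch.Tensor):
    """Return the name of the first leaf module (in execution order) whose
    output contains NaN or Inf, or None if every hooked output is finite."""
    hits = []
    handles = []

    def make_hook(name):
        def hook(module, args, output):
            # Record only the first offending module; later ones usually
            # just propagate the same NaN/Inf values.
            if not hits and isinstance(output, torch.Tensor):
                if not torch.isfinite(output).all():
                    hits.append(name)
        return hook

    for name, module in model.named_modules():
        if next(module.children(), None) is None:  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        model(inputs)
    finally:
        for handle in handles:  # always detach the hooks again
            handle.remove()
    return hits[0] if hits else None
```

Autocast state is thread-local, so calling the helper inside the same block as the example, i.e. with torch.cuda.amp.autocast(): print(find_first_nonfinite(resnet, inputs)), runs the hooked forward pass under autocast and prints a dotted module name from named_modules(), or None.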

Unfortunately, I cannot reproduce the issue.
No matter how often I try, I never get NaN.
This might be a version problem.

I tried it on:
CUDA 10.2 / PyTorch 1.8.1
as well as on:
CUDA 10.2 / PyTorch 1.9.0

What are you using?
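
Since the result depends on the setup, it helps to report the exact PyTorch build, CUDA runtime, cuDNN version and GPU. PyTorch ships a script that prints a full report (python -m torch.utils.collect_env); the sketch below collects a smaller subset programmatically. environment_report is a hypothetical helper name; the calls inside it are standard torch functions.

```python
import platform

import torch


def environment_report() -> dict:
    """Hypothetical helper: collect version details relevant to AMP numerical issues."""
    report = {
        "python": platform.python_version(),
        "torch": torch.__version__,
        "cuda_runtime": torch.version.cuda,  # CUDA version torch was built with; None on CPU-only builds
        "cudnn": torch.backends.cudnn.version(),  # None if cuDNN is unavailable
        "gpu": None,
    }
    if torch.cuda.is_available():
        report["gpu"] = torch.cuda.get_device_name(0)
        report["compute_capability"] = torch.cuda.get_device_capability(0)
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```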

Oh, you are right. If I use the pytorch/pytorch:1.9.0-cuda10.2-cudnn7-devel docker image to run the example, the outputs are fine.

In my environment, I installed torch==1.9.0, torchvision=0.10.0 and cudatoolkit=11.1 using conda. It seems the NaN output only occurs with CUDA 11.1; when I install cudatoolkit=10.2 instead, everything works fine.

Thanks for your help! Very much appreciated :smiley: :+1:

Which GPU are you using when you get the NaN outputs with PyTorch 1.9.0?

I am using a GTX 1650.