I want to use AMP with a ResNet-18 that was trained without AMP (plain float32) on CIFAR-10.
However, when I wrap the forward pass of the model in a
`torch.cuda.amp.autocast()` block, the output of the network becomes
`nan`. When I disable AMP with
`torch.cuda.amp.autocast(enabled=False)`, I get the expected output values.
Below I have attached a minimal example that reproduces my problem.
Does anyone have an idea why the output is always `nan`?
```python
import torch
from models.cifar10_models.resnet import ResNet18

resnet = ResNet18(num_classes=10)
resnet.load_state_dict(torch.load('./pretrained_models/resnet18_cifar10.pth'))

inputs = torch.randn((124, 3, 32, 32), requires_grad=True).cuda()
resnet.cuda()
resnet.eval()

for i in range(500):
    with torch.cuda.amp.autocast():
        outputs = resnet(inputs)
        print(outputs)
```
I have noticed that the problem is difficult to reproduce when loading a local model. Therefore, here is another example using a ResNet-18 pretrained on ImageNet. Running the code below also produces
`nan` as output:
```python
import torch
from torchvision.models.resnet import resnet18

resnet = resnet18(pretrained=True)
resnet.fc = torch.nn.Linear(512, 10)

inputs = torch.randn((124, 3, 32, 32), requires_grad=False).cuda()
resnet.cuda()
resnet.eval()

with torch.cuda.amp.autocast(enabled=True):
    outputs = resnet(inputs)
    print(outputs)
```
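In case it helps with debugging: here is a sketch of how one could localize which layer first produces non-finite values, using forward hooks (the helper name and the small stand-in model are my own for illustration; in the real setup the hooks would be registered on the ResNet-18 and the forward pass wrapped in `torch.cuda.amp.autocast()` as above):

```python
import torch
import torch.nn as nn

def make_nan_check_hook(name):
    # Forward hook that reports a module whose output contains inf/nan.
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
            print(f"non-finite values appear in: {name} ({type(module).__name__})")
    return hook

# Small stand-in model for illustration; hooks attach the same way on ResNet-18.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.eval()
for name, module in model.named_modules():
    if name:  # skip the root container itself
        module.register_forward_hook(make_nan_check_hook(name))

x = torch.randn(4, 3, 32, 32)
out = model(x)  # with finite inputs, no hook should fire
```

The first module the hooks flag would be a good candidate for an overflow in float16 (e.g. large activations exceeding the fp16 range under autocast).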