Hello,
I want to use AMP on a ResNet-18 that was trained without AMP (plain float32) on CIFAR-10. However, when I wrap the forward pass of the model in a torch.cuda.amp.autocast() block, the output of the network becomes NaN. When I disable AMP with torch.cuda.amp.autocast(enabled=False), I get the expected output values.
Below I attached a minimal example that reproduces my problem. Does anyone have an idea why the output is always NaN?
import torch
from models.cifar10_models.resnet import ResNet18

# ResNet-18 trained on CIFAR-10 in plain float32
resnet = ResNet18(num_classes=10)
resnet.load_state_dict(torch.load('./pretrained_models/resnet18_cifar10.pth'))

inputs = torch.randn((124, 3, 32, 32), requires_grad=True).cuda()

resnet.cuda()
resnet.eval()

for i in range(500):
    with torch.cuda.amp.autocast():
        outputs = resnet(inputs)
    print(outputs)
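For context, here is a small CPU-only illustration (my own, not part of the repro) of the float16 behaviour I suspect is involved: float16 can only represent values up to 65504, so an intermediate activation that overflows becomes inf, and subsequent arithmetic on inf can then produce NaN.

```python
import torch

# float16 has a maximum representable value of 65504; anything larger
# overflows to inf, and e.g. inf - inf then yields NaN.
x = torch.tensor([300.0], dtype=torch.float16)
y = x * x               # 90000 > 65504 -> overflows to inf
print(y)                # tensor([inf], dtype=torch.float16)
print(y - y)            # inf - inf -> tensor([nan], dtype=torch.float16)
print(torch.finfo(torch.float16).max)  # 65504.0
```

If an activation inside the network exceeds that range while autocast runs the op in float16, the inf/NaN would then propagate to the output.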
Edit:
I have noticed that the problem is difficult to reproduce when loading a local model, so here is another example using a ResNet-18 pretrained on ImageNet. Running the code below also produces NaN as output:
import torch
from torchvision.models.resnet import resnet18

# ResNet-18 pretrained on ImageNet, with a new 10-class head
resnet = resnet18(pretrained=True)
resnet.fc = torch.nn.Linear(512, 10)

inputs = torch.randn((124, 3, 32, 32), requires_grad=False).cuda()

resnet.cuda()
resnet.eval()

with torch.cuda.amp.autocast(enabled=True):
    outputs = resnet(inputs)
print(outputs)
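To narrow down where the non-finite values first appear, one could register forward hooks on every module and record the first one whose output contains NaN or Inf. This is a sketch of that idea (the toy model and the Overflow layer are mine, just to make it runnable anywhere on CPU); the same pattern can be applied to the ResNet under autocast:

```python
import torch
import torch.nn as nn

def attach_nonfinite_hooks(model):
    """Register forward hooks that record the names of modules whose
    output contains NaN or Inf, in the order they fire."""
    report = []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                report.append(name)
        return hook

    handles = [m.register_forward_hook(make_hook(n))
               for n, m in model.named_modules() if n]
    return report, handles

# Toy stand-in for the network; the custom layer injects an overflow
# similar to what float16 can produce on large activations.
class Overflow(nn.Module):
    def forward(self, x):
        return x * 1e30  # pushes values past the float32 range -> inf

model = nn.Sequential(nn.Linear(8, 8), Overflow(), nn.Linear(8, 8))
report, handles = attach_nonfinite_hooks(model)

x = torch.full((2, 8), 1e10)  # large input so the scaled values overflow
_ = model(x)

for h in handles:
    h.remove()

# report[0] names the first module that emitted a non-finite output
print("first non-finite output at module:", report[0] if report else "none")
```

Applied to the repro above, the first recorded module name would point at the layer where float16 overflow starts, which should help confirm (or rule out) overflow as the cause.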