I have been testing AMP inference on the Jetson Nano with several model architectures, using the CIFAR-10 test set with a batch size of 32, and the inference times vary a lot between models. Compared to float32, ResNet runs about 10-15% slower, VGG runs 30% faster, and MobileNetV2 takes about the same time.
I have read that the Jetson Nano should not see much speedup because it lacks Tensor Cores, but the spread in these results is unexpected, especially the slowdown on ResNet.
I’ve tried running the profiler on a smaller subset of the data because of limited memory, but the output does not seem to provide much information.
All the architectures have been tested in the same way, with the models and weights coming from here.
The code for testing:
```python
import time

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10

DEVICE = "cuda:0"


def load_data():
    testtransform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.201)),
        ]
    )
    # Use only the first 1000 test images so the profiler fits in memory.
    subset = list(range(0, 1000))
    testset = CIFAR10(".", train=False, download=True, transform=testtransform)
    testset = Subset(testset, subset)
    testloader = DataLoader(testset, batch_size=32)
    return testloader


def test(net, testloader):
    correct, total = 0, 0
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        profile_memory=True,
        with_flops=True,
        use_cuda=True,
    ) as p:
        with torch.cuda.amp.autocast_mode.autocast(dtype=torch.float16):
            with torch.no_grad():
                for data in testloader:
                    # data is an (images, labels) pair; move each tensor to the GPU.
                    images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
                    outputs = net(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
    print(p.key_averages().table(sort_by="cuda_time_total"))
    accuracy = correct / total
    return accuracy


net = torch.hub.load(
    repo_or_dir="chenyaofo/pytorch-cifar-models",
    model="cifar10_resnet20",
    pretrained=True,
)
net.eval()
net = net.to(DEVICE)
testloader = load_data()

# Note: this wall-clock time also includes data loading and profiler overhead.
start = time.time()
accuracy = test(net, testloader)
end = time.time()
print("Accuracy:", accuracy, "\nTime taken:", end - start)
```
The tests have been carried out on Python 3.8, PyTorch 1.10, torchvision 0.11, CUDA 10.2.
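Since the timing above wraps the whole `test` call, it also counts data loading, host-to-device copies, and profiler overhead, which may hide or distort the float32 vs. AMP difference. Below is a rough sketch of how the two modes could be timed more directly, with warm-up batches skipped and `torch.cuda.synchronize` called around each forward pass. The helper names `time_batches` and `compare_fp32_and_amp` are made up for illustration, and I am assuming `net` and `testloader` are the ones defined in the script above; treat this as a starting point, not a definitive benchmark.

```python
import time
from contextlib import nullcontext


def time_batches(run_batch, batches, sync=lambda: None, warmup=3):
    """Return mean seconds per batch, ignoring the first `warmup` batches.

    `sync` is called before each clock read so that asynchronously queued
    work (for example CUDA kernels) is included in the measurement.
    """
    timings = []
    for i, batch in enumerate(batches):
        sync()
        start = time.perf_counter()
        run_batch(batch)
        sync()
        elapsed = time.perf_counter() - start
        if i >= warmup:
            timings.append(elapsed)
    if not timings:
        raise ValueError("need more batches than warm-up iterations")
    return sum(timings) / len(timings)


def compare_fp32_and_amp(net, testloader, device="cuda:0"):
    """Hypothetical helper: time the same model with and without autocast."""
    import torch  # imported lazily so time_batches has no dependencies

    # Copy the images to the GPU up front so loading and transfers are not timed.
    batches = [images.to(device) for images, _ in testloader]

    def amp_context():
        return torch.cuda.amp.autocast(dtype=torch.float16)

    results = {}
    for name, make_context in (("fp32", nullcontext), ("amp", amp_context)):
        with torch.no_grad(), make_context():
            results[name] = time_batches(net, batches, sync=torch.cuda.synchronize)
    return results
```

With the script above, something like `print(compare_fp32_and_amp(net, testloader))` would give a mean per-batch time for each mode. The profiler is left out on purpose here, so its overhead does not end up in the numbers being compared.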