I’m new to PyTorch, and I’m trying to get a basic sanity check working: run a JPEG image through the pretrained ResNet-152 model and classify it using the ImageNet labels. However, the classifications are wildly inaccurate; the predictions for a photo of a golden retriever are all over the map. Currently, running this script yields:
I’m wondering whether I’m doing something wrong when I normalize the JPEG, or whether I’m not processing the ResNet outputs correctly. My short script is below; it includes the links from which I downloaded the test picture and the JSON file I used to look up class labels. (I double-checked this file against several other ImageNet label files, and I believe it is the correct mapping from indices to labels.)
Here are some other resources I have consulted or incorporated into my script:
ImageNet classes (for labels)
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py (for normalization)
Any help would be greatly appreciated!
```python
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet152
from PIL import Image
import json

## Import ImageNet Class Labels
# http://files.fast.ai/models/imagenet_class_index.json
labels = json.load(open('imagenet_class_index.json'))

## Get an image and normalize it
# https://upload.wikimedia.org/wikipedia/commons/8/82/Golden_Retriever_standing_Tucker.jpg
im = Image.open('images/golden.jpg')
im.show()

# normalization for ImageNet taken from Torch examples on GitHub
normalize = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# sanity check
im_tensor = normalize(im)
to_pil = transforms.ToPILImage()
to_pil(im_tensor).show()

## Classify the image
net = resnet152(pretrained=True)
outputs = net(im_tensor.unsqueeze(0))

for idx in outputs.sort()[-5:]:
    print(idx.item(), labels[str(idx.item())])

# _ , pred = torch.max(outputs, 1)
# class_idx = pred.data.numpy().argmax()
# label = labels[str(class_idx)]
# print(label)
```