Hello,
I’m a beginning PyTorch user, and I’m trying to get a sanity check working: I want to run a JPEG image through the pretrained ResNet model and classify it according to the ImageNet labels. However, the classifications are wildly inaccurate. When I try to classify a golden retriever, the predictions are all over the map. Running the script currently yields:
733 pole
600 hook
700 paper_towel
463 bucket
852 tennis_ball
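A note on reading the list above: `outputs[0].sort()` sorts in ascending order, so `[-5:]` keeps the five highest scores with the best one last, which makes 852 `tennis_ball` the top-1 prediction. The sketch below uses made-up logits (a hypothetical six-class `outputs` tensor, not real ResNet output) to compare that ordering with `torch.topk`, which returns the top scores in descending order:

```python
import torch

# Hypothetical logits for one image over six classes (ResNet would give 1 x 1000).
outputs = torch.tensor([[0.1, 2.0, -1.0, 5.0, 3.0, 0.5]])

# What the script does: ascending sort, keep the last five indices (best last).
ascending_top5 = outputs[0].sort()[1][-5:]
print(ascending_top5.tolist())  # [0, 5, 1, 4, 3]

# torch.topk returns the k largest entries, highest first.
_, top_idx = torch.topk(outputs[0], k=5)
print(top_idx.tolist())  # [3, 4, 1, 5, 0]

# Softmax turns logits into probabilities without changing the ranking.
probs = torch.softmax(outputs[0], dim=0)
print(probs.argmax().item())  # 3

# torch.max over dim 1 already returns the index of the best class.
_, pred = torch.max(outputs, 1)
print(pred.item())  # 3
```

Both approaches rank the same way; `topk` just saves reversing the list and, combined with softmax, gives readable probabilities.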
I’m wondering whether I’m doing something wrong when I normalize the JPEG, or whether I’m not processing the ResNet outputs correctly. My short script is included below; it contains the links where I downloaded the test picture and the JSON file I used to look up class labels. (I double-checked this file against several other ImageNet label files; I believe it is the correct mapping from indices to labels.)
Here are some other resources I consulted or incorporated into my script:
- ImageNet classes (for labels)
- https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py (for normalization)
Any help would be greatly appreciated!
```python
import json

import torch
import torchvision.transforms as transforms
from torchvision.models import resnet152
from PIL import Image

## Import ImageNet class labels
# http://files.fast.ai/models/imagenet_class_index.json
# Maps "0".."999" to [wordnet_id, label]
with open('imagenet_class_index.json') as f:
    labels = json.load(f)

## Get an image and normalize it
# https://upload.wikimedia.org/wikipedia/commons/8/82/Golden_Retriever_standing_Tucker.jpg
im = Image.open('images/golden.jpg')
im.show()

# Normalization for ImageNet taken from the Torch examples on GitHub
normalize = transforms.Compose([
    transforms.Resize(256),  # transforms.Scale is deprecated in favor of Resize
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Sanity check (colors look off here because the tensor is still normalized)
im_tensor = normalize(im)
to_pil = transforms.ToPILImage()
to_pil(im_tensor).show()

## Classify the image
net = resnet152(pretrained=True)
outputs = net(im_tensor.unsqueeze(0))

# Print the five highest-scoring classes (ascending order, best last)
for idx in outputs[0].sort()[1][-5:]:
    print(idx.item(), labels[str(idx.item())][1])

# _, pred = torch.max(outputs, 1)
# class_idx = pred.item()  # pred already holds the index of the best class
# label = labels[str(class_idx)][1]
# print(label)
```
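One difference from the usual inference recipe that may explain the results: the script never calls `net.eval()`. A newly constructed model, including one returned by `resnet152(pretrained=True)`, starts in training mode, where BatchNorm layers normalize activations with the statistics of the current batch (here a single image) instead of the running averages stored with the pretrained weights. The sketch below uses a hypothetical toy network (`Conv2d` followed by `BatchNorm2d`) in place of ResNet so that it runs without downloading weights; it shows that the same input gives different outputs in the two modes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy network with a BatchNorm layer, standing in for ResNet.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))

# Give BatchNorm non-trivial running statistics, as a pretrained model would have.
with torch.no_grad():
    model[1].running_mean.uniform_(-1.0, 1.0)
    model[1].running_var.uniform_(0.5, 2.0)

x = torch.randn(1, 3, 32, 32)  # a single image, like im_tensor.unsqueeze(0)

# Training mode (the default): BatchNorm uses this batch's mean and variance.
# Note that this forward pass also updates the running statistics.
model.train()
with torch.no_grad():
    out_train = model(x)

# Eval mode: BatchNorm uses the stored running statistics instead.
model.eval()
with torch.no_grad():
    out_eval = model(x)

print(torch.allclose(out_train, out_eval))  # False
```

Applied to the script above, the change would presumably be to call `net.eval()` right after `net = resnet152(pretrained=True)` and to run the forward pass inside `with torch.no_grad():`, which also skips gradient bookkeeping.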