Trouble getting pretrained ResNet152 to classify images properly

Hello,

I’m new to PyTorch, and I’m trying to get a sanity check working: I want to run a JPEG image through the pretrained ResNet152 model and classify it according to the ImageNet labels. However, the classifications are wildly inaccurate. The predictions for a photo of a golden retriever are all over the map. Currently, running this script prints these top-5 predictions (class index and label):

733 pole
600 hook
700 paper_towel
463 bucket
852 tennis_ball

I’m wondering whether I’m doing something wrong when I normalize the JPEG, or whether I’m not processing the ResNet outputs correctly. My short script is below. It includes the links where I downloaded the test picture and the JSON file I used to look up class labels. (I double-checked this file against several other ImageNet label files, and I believe it is the correct mapping from indices to labels.)

Here are some other resources I consulted or incorporated into my script:
ImageNet classes (for labels)
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py (for normalization)
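For reference, `transforms.Normalize(mean, std)` computes `(x - mean[c]) / std[c]` for each channel `c` of a tensor that `ToTensor()` has already scaled to [0, 1]. A minimal sketch of that arithmetic in plain PyTorch, using the ImageNet mean and std from the script (the helper name `imagenet_normalize` is hypothetical, for illustration only):

```python
import torch

# ImageNet per-channel statistics (R, G, B), the same values used in the script
MEAN = torch.tensor([0.485, 0.456, 0.406])
STD = torch.tensor([0.229, 0.224, 0.225])


def imagenet_normalize(x: torch.Tensor) -> torch.Tensor:
    """Normalize a (3, H, W) tensor with values in [0, 1], channel by channel."""
    # Reshape the statistics to (3, 1, 1) so they broadcast over height and width.
    return (x - MEAN[:, None, None]) / STD[:, None, None]


# A 2x2 "image" in which every pixel has exactly the ImageNet mean color.
mean_img = MEAN[:, None, None].expand(3, 2, 2).clone()
print(imagenet_normalize(mean_img).abs().max().item())  # 0.0

# Pure black and pure white pixels give the extremes of the normalized range.
print(imagenet_normalize(torch.zeros(3, 1, 1)).flatten())  # about [-2.12, -2.04, -1.80]
print(imagenet_normalize(torch.ones(3, 1, 1)).flatten())   # about [2.25, 2.43, 2.64]
```

Because normalized values fall roughly in [-2.1, 2.6] rather than [0, 1], viewing the normalized tensor with `ToPILImage()` is not a reliable way to check the preprocessing.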

Any help would be greatly appreciated!

```python
import json

import torch
import torchvision.transforms as transforms
from torchvision.models import resnet152
from PIL import Image

## Import ImageNet Class Labels

# http://files.fast.ai/models/imagenet_class_index.json
with open('imagenet_class_index.json') as f:
    labels = json.load(f)

## Get an image and normalize it

# https://upload.wikimedia.org/wikipedia/commons/8/82/Golden_Retriever_standing_Tucker.jpg
# convert('RGB') guards against grayscale or RGBA files
im = Image.open('images/golden.jpg').convert('RGB')
im.show()
# normalization for ImageNet taken from Torch examples on GitHub
# (transforms.Resize replaces the deprecated transforms.Scale)
normalize = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

# sanity check (colors will look off because the tensor is already normalized)
im_tensor = normalize(im)
to_pil = transforms.ToPILImage()
to_pil(im_tensor).show()

## Classify the image

net = resnet152(pretrained=True)
outputs = net(im_tensor.unsqueeze(0))

# sort() is ascending, so the last five indices are the top-5 classes
for idx in outputs[0].sort()[1][-5:]:
    print(idx.item(), labels[str(idx.item())][1])

# _, pred = torch.max(outputs, 1)
# class_idx = pred.item()  # pred already holds the index; calling argmax() on it always gives 0
# label = labels[str(class_idx)][1]
# print(label)
```
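The ranking loop above works, but `torch.topk` returns the best classes directly in descending order, and `torch.softmax` turns the raw logits into probabilities. A minimal sketch, assuming the same label-file layout as `imagenet_class_index.json` (string index mapped to `[wordnet_id, label]`); the helper `top_k_labels`, the four-entry label dict, and the logits are all hypothetical, made up for illustration rather than taken from the real model:

```python
import torch


def top_k_labels(logits: torch.Tensor, labels: dict, k: int = 5):
    """Return (index, label, probability) for the k highest-scoring classes.

    `logits` is the 1-D output for a single image; `labels` maps str(index)
    to [wordnet_id, human_readable_name], like imagenet_class_index.json.
    """
    probs = torch.softmax(logits, dim=0)
    top_probs, top_idx = torch.topk(probs, k)  # sorted in descending order
    return [(i.item(), labels[str(i.item())][1], p.item())
            for p, i in zip(top_probs, top_idx)]


# Hypothetical 4-class label file and logits, for illustration only.
labels = {
    "0": ["wnid0", "tench"],
    "1": ["wnid1", "goldfish"],
    "2": ["wnid2", "golden_retriever"],
    "3": ["wnid3", "tennis_ball"],
}
logits = torch.tensor([0.5, -1.0, 4.0, 2.0])

top3 = top_k_labels(logits, labels, k=3)
for idx, name, p in top3:
    print(f"{idx} {name} {p:.3f}")  # golden_retriever is listed first
```

With the real model, the same helper would be called as `top_k_labels(outputs[0], labels)`.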

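The pretrained models are created in training mode, so their BatchNorm layers normalize activations with statistics computed from the current input (here a single image) instead of the running statistics saved with the pretrained weights.
Call net.eval() after loading the model and run it again.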
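To see why this matters, here is a small self-contained sketch with one standalone `nn.BatchNorm2d` layer (not the real ResNet152, which would need a weight download); the running statistics and the input are invented for illustration. In training mode the layer normalizes with the batch's own mean and variance and updates its running statistics in place; in eval mode it uses the stored running statistics:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One standalone BatchNorm layer standing in for the many inside ResNet152.
bn = nn.BatchNorm2d(3)
initially_training = bn.training
print(initially_training)  # True: newly constructed modules start in training mode

with torch.no_grad():
    # Pretend these running statistics were learned during training.
    bn.running_mean.fill_(5.0)
    bn.running_var.fill_(4.0)

# A single "image" (batch size 1) whose pixel values average about 10.
x = 10.0 + torch.randn(1, 3, 8, 8)

# Eval mode: normalize with the stored running statistics.
bn.eval()
with torch.no_grad():
    y_eval = bn(x)

# Train mode: normalize with this batch's own per-channel mean and variance,
# and, as a side effect, update the running statistics in place.
bn.train()
with torch.no_grad():
    y_train = bn(x)

print(y_eval.mean().item())   # close to (10 - 5) / 2 = 2.5
print(y_train.mean().item())  # close to 0: the input's offset is normalized away
print(bn.running_mean)        # about 5.5 per channel (default momentum 0.1)
```

Every BatchNorm layer in ResNet152 behaves this way, so in training mode each one re-normalizes using the single input image instead of the statistics the network was trained with, which is consistent with the unrelated labels in the question. torchvision's ResNet has no dropout layers, so for this model BatchNorm is what `eval()` changes.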

Thanks so much, that fixed my issue.