Pre-trained network demo

I’m trying to get a pre-trained network to output something that makes sense… but I’m having some trouble.
Here’s the snippet.

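# get input image
import os
from PIL import Image
file_name = '26132.jpg'
if not os.access(file_name, os.R_OK):
    file_URL = ''
    os.system('wget ' + file_URL)
img = Image.open(file_name).convert('RGB')  # the torchvision transforms below expect a PIL image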

# get model
import torchvision
resnet_18 = torchvision.models.resnet18(pretrained=True)

# get classes
file_name = 'synset_words.txt'
if not os.access(file_name, os.R_OK):  # only read access is needed
    synset_URL = ''
    os.system('wget ' + synset_URL)
classes = list()
with open(file_name) as class_file:
    for line in class_file:
        classes.append(line.strip().split(' ', 1)[1].split(', ', 1)[0])
classes = tuple(classes)

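# define image transformation
from torchvision import transforms as trn
centre_crop = trn.Compose([
        trn.Resize(256),      # shorter side to 256 px
        trn.CenterCrop(224),  # 224x224 centre crop, the ResNet input size
        trn.ToTensor(),       # PIL image -> CxHxW float tensor in [0, 1]
        trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean/std
])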

# get top 5 probabilities
import torch
from torch.nn import functional as f
x = centre_crop(img).unsqueeze(0)  # add a batch dimension: 1x3x224x224
with torch.no_grad():  # replaces the deprecated Variable(..., volatile=True)
    logit = resnet_18(x)  # calling the module runs forward() plus any registered hooks
h_x = f.softmax(logit, dim=1).squeeze()
probs, idx = h_x.sort(0, True)
for i in range(0, 5):
    print('{:.3f} -> {}'.format(probs[i].item(), classes[idx[i].item()]))
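To make the post-processing at the end of the snippet easier to follow, here is a pure-Python sketch (no PyTorch) of the two steps it performs: turning a `synset_words.txt` line into a short class name, and turning raw logits into the top-k (probability, class) pairs with a softmax. The helper names `short_name` and `top_k` are illustrative only; the parsing expression is the same one used in the snippet.

```python
# Pure-Python sketch of the snippet's post-processing (no PyTorch needed):
# short class names from synset_words.txt lines, then softmax + top-k.
# short_name and top_k are hypothetical helpers, not library functions.
import math
from typing import List, Sequence, Tuple

def short_name(synset_line: str) -> str:
    # "n01440764 tench, Tinca tinca" -> drop the WordNet ID, keep the first synonym
    return synset_line.strip().split(' ', 1)[1].split(', ', 1)[0]

def top_k(logits: Sequence[float], classes: Sequence[str],
          k: int = 5) -> List[Tuple[float, str]]:
    # Softmax; subtracting the largest logit keeps math.exp from overflowing
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Indices sorted by descending probability, like h_x.sort(0, True)
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    return [(probs[i], classes[i]) for i in order[:k]]

for p, name in top_k([0.0, 2.0, 1.0], ["a", "b", "c"], k=2):
    print('{:.3f} -> {}'.format(p, name))
```

With 1000 classes, perfectly flat logits give 0.001 per class, so a top-1 probability of only a few thousandths means the network is barely distinguishing between classes at all.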

And this is the output

0.009 -> bucket
0.007 -> plunger
0.006 -> hook
0.005 -> water bottle
0.005 -> water jug

which should be, instead, roughly

0.99 -> German shepherd
0.01 -> malinois
0.00 -> Norwegian elkhound
0.00 -> Leonberg
0.00 -> red wolf

And this is the input picture.

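It’s missing a call to resnet_18.eval() before the forward pass. Without it, the network stays in training mode, where batch normalization normalizes each channel with the statistics of the current batch (here, a single image) instead of the running averages stored with the pretrained weights, and it also updates those running averages as a side effect.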
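The difference between the two modes can be seen in a toy model of a single batch-norm channel. The sketch below is a simplified pure-Python re-implementation, not PyTorch's code: it assumes the affine weight and bias are 1 and 0, and it copies the documented update rule of `torch.nn.BatchNorm2d` (momentum 0.1, eps 1e-5, biased variance for normalizing, unbiased variance for the running estimate). `ToyBatchNorm` is a hypothetical name.

```python
# Simplified, pure-Python model of ONE batch-norm channel (weight=1, bias=0),
# written to contrast training and eval mode. It mimics the update rule of
# torch.nn.BatchNorm2d (momentum=0.1, eps=1e-5) but is NOT PyTorch's code.
import math
from typing import List, Sequence

class ToyBatchNorm:
    def __init__(self, running_mean: float, running_var: float,
                 momentum: float = 0.1, eps: float = 1e-5) -> None:
        self.running_mean = running_mean  # statistics learned during training
        self.running_var = running_var
        self.momentum = momentum
        self.eps = eps
        self.training = True  # like nn.Module, a new layer starts in training mode

    def eval(self) -> "ToyBatchNorm":
        self.training = False
        return self

    def __call__(self, values: Sequence[float]) -> List[float]:
        # `values` are all activations of this channel in the batch
        # (for BatchNorm2d: every spatial position of every image).
        if self.training:
            n = len(values)
            if n < 2:
                raise ValueError("need more than 1 value per channel when training")
            mean = sum(values) / n
            var = sum((v - mean) ** 2 for v in values) / n  # biased, used to normalize
            # Side effect: the stored statistics drift toward this batch
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var * n / (n - 1)
        else:
            # Eval mode: use the stored statistics, leave them untouched
            mean, var = self.running_mean, self.running_var
        return [(v - mean) / math.sqrt(var + self.eps) for v in values]
```

Fed the same three activations, a layer in training mode re-centres them on their own mean (the middle value maps to 0) and changes its running statistics, while the same layer after `.eval()` maps them relative to the stored mean and variance. In PyTorch itself the fix is `resnet_18.eval()`; `torch.no_grad()` is a separate switch that only disables gradient tracking.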

Pretrained networks saved in train mode? :cold_sweat:
OK, this was hard to guess. :confused:

Sweet. Something works today, finally. Phew.
Thank you, @colesbury.

0.935 -> German shepherd
0.033 -> Leonberg
0.031 -> malinois
0.000 -> Norwegian elkhound
0.000 -> African hunting dog

What do you think: shall we embed the classes tuple, the input image size, and the normalisation settings in the network object, and have it in eval mode by default?
IMO, the code above should really be a one-liner.
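A minimal sketch of what that proposal could look like, assuming a plain callable stands in for the network and that resizing/cropping is handled elsewhere (`input_size` is only stored). `PackagedClassifier` and its methods are hypothetical names, not a torchvision API. For comparison, recent torchvision releases (0.13 and later) went partly in this direction: weight enums such as `ResNet18_Weights.IMAGENET1K_V1` expose their preprocessing via `.transforms()` and the class names via `.meta["categories"]`, although the model itself is still created in training mode.

```python
# Hypothetical sketch of the proposal: a classifier that carries its own class
# names and normalisation settings and starts in eval mode.
# PackagedClassifier is an illustrative name, NOT a torchvision API.
from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple

Pixel = Tuple[float, float, float]  # (R, G, B) values in [0, 1]

@dataclass
class PackagedClassifier:
    network: Callable[[List[List[float]]], List[float]]  # normalised pixels -> logits
    classes: Sequence[str]
    input_size: int = 224  # stored for callers; resizing is omitted in this sketch
    mean: Pixel = (0.485, 0.456, 0.406)
    std: Pixel = (0.229, 0.224, 0.225)
    training: bool = False  # eval mode by default, as proposed

    def eval(self) -> "PackagedClassifier":
        self.training = False
        return self

    def normalise(self, pixels: Sequence[Pixel]) -> List[List[float]]:
        # Per-channel (value - mean) / std, as trn.Normalize does
        return [[(v - m) / s for v, m, s in zip(p, self.mean, self.std)]
                for p in pixels]

    def predict(self, pixels: Sequence[Pixel]) -> str:
        # Design choice: fail loudly instead of silently predicting in training mode
        if self.training:
            raise RuntimeError("classifier is in training mode; call eval() first")
        logits = self.network(self.normalise(pixels))
        best = max(range(len(logits)), key=logits.__getitem__)
        return self.classes[best]
```

With everything bundled, the user-facing call reduces to something like `label = clf.predict(pixels)`.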

At the very least, I should send a PR to include this example in the torchvision documentation, right?

For whoever might be interested, I put a notebook with working code here.
