Pretrained ResNet constant output

I was trying some experiments with pretrained ResNets, but couldn’t get them to correctly predict some basic images. I’m used to fine-tuning networks in PyTorch, but have never used them “raw”. Whatever image I input, the model always predicts the same category (“bucket, pail”).

Is there anything I’m doing wrong here?

from torchvision import models, transforms
from torch.autograd import Variable
import torch.nn as nn
import numpy as np
from PIL import Image
import requests
from io import BytesIO

model = models.resnet18(pretrained=True)  # ImageNet weights; num_classes defaults to 1000

trans = transforms.Compose([
    transforms.Resize(224),  # transforms.Scale was renamed to Resize in later torchvision releases
    transforms.CenterCrop(224),
    transforms.ToTensor(),  # converts to a float tensor in [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet statistics, from http://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
])

url = 'https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/' \
      'raw/596b27d23537e5a1b5751d2b0481ef172f58b539/imagenet1000_clsid_to_human.txt'

imagenet_classes = eval(requests.get(url).content)  # the gist contains a dict literal mapping class index -> label

images = [('cat', 'https://www.wired.com/wp-content/uploads/2015/02/catinbox_cally_by-helen-haden_4x31-660x495.jpg'),
          ('pomeranian', 'https://c.photoshelter.com/img-get/I0000q_DdkyvP6Xo/s/900/900/Pomeranian-Dog-with-Ball.jpg'),
          ('car', 'https://static.pexels.com/photos/24353/pexels-photo.jpg')]


for class_name, image_url in images:
    response = requests.get(image_url)
    im = Image.open(BytesIO(response.content))

    tens = Variable(trans(im))
    tens = tens.view(1, 3, 224, 224)  # add a batch dimension
    preds = nn.LogSoftmax(dim=1)(model(tens)).data.cpu().numpy()
    res = np.argmax(preds)
    print('true (likely) label:', class_name)
    print('predicted', imagenet_classes[res], '\n')

You have to first divide im by 255.

I suspect:

im = Image.open(BytesIO(response.content))

is returning the image with values from 0 to 255, rather than the [0, 1] range the rest of the code expects.

I thought the transforms.ToTensor() class took care of that.

The tensors appear, in any case, to have values in [0, 1] (roughly zero-centred after normalization).
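A quick way to verify this is a minimal sketch with a random dummy image (the values here are arbitrary):

import numpy as np
from PIL import Image
from torchvision import transforms

arr = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)  # fake 8-bit RGB image
im = Image.fromarray(arr)

tens = transforms.ToTensor()(im)
print(tens.min().item(), tens.max().item())  # close to 0.0 and 1.0: ToTensor divides by 255

So the scaling suggestion doesn’t explain the constant predictions.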


Strangely, the exact same code using a pretrained VGG11 seems to accurately predict the correct classes. Every ResNet, from 18 to 152, gives constant predictions.

I am facing the same issue with constant predictions. Were you able to solve this problem? If so, please help.

Also, I have another issue: I would like to use a pretrained model to predict one of my own 28 classes, so I suppose I should add a Softmax layer. I did this using add_module; however, Softmax is still not applied to the model’s output (see the sketch below).

I have no experience with PyTorch or any other Deep Learning frameworks. Could someone please help?
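For what it’s worth, add_module only registers a child module; nothing calls it unless the model’s forward() does, which is why the output is unchanged. A minimal sketch of the usual pattern, assuming a torchvision ResNet50 and the 28-class task above (the variable names are just illustrative):

import torch.nn as nn
from torchvision import models

model = models.resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 28)  # replace the 1000-class head with a 28-class one

# wrapping the model makes the Softmax part of the forward pass
model_with_probs = nn.Sequential(model, nn.Softmax(dim=1))

If you train with nn.CrossEntropyLoss, feed it the raw logits from model instead, since that loss applies LogSoftmax internally.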

ResNets use batch normalisation, so you will need to call model.eval() once the model is loaded, before passing your inputs in. VGG doesn’t, which is why you still see the right answers even when you forget to call .eval().
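Applied to the code in the first post, the fix is a single line after loading the weights (a minimal sketch):

from torchvision import models

model = models.resnet18(pretrained=True)
model.eval()  # BatchNorm layers now use their running statistics instead of per-batch statistics

You can also check that the default VGG11 indeed contains no BatchNorm layers (only the vgg11_bn variant does):

import torch.nn as nn
print(any(isinstance(m, nn.BatchNorm2d) for m in models.vgg11().modules()))  # False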


I can confirm that this was the issue for me! 😅
Thanks!

Kai,

This worked for me too. Finding this solution was way too hard: I did not see it in the documentation anywhere, and I searched every which way until your solution came up. The PyTorch models documentation needs to explain this. I could get things working in Keras, but not ResNet50 in PyTorch.

Do I need to call model.eval() as well when I load a custom, fine-tuned model based on ResNet50, or is this unique to the ImageNet version?

Thank you!

model.eval() sets the model into evaluation mode, i.e. BatchNorm and Dropout layers will behave differently. model.train() sets it back into training mode. This affects all nn.Modules in general and is not specific to a dataset.

If your custom ResNet uses layers which behave differently during training and evaluation, you should definitely use it.
In general, I would recommend always using it before evaluation.
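To make the difference concrete, here is a minimal sketch showing that a BatchNorm layer produces different outputs in the two modes:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
x = torch.randn(2, 3, 4, 4)

bn.train()
out_train = bn(x)  # normalizes with this batch's own statistics (and updates the running stats)

bn.eval()
out_eval = bn(x)   # normalizes with the accumulated running statistics

print(torch.allclose(out_train, out_eval))  # False in general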


We’ve now updated the documentation for torchvision.models to indicate the use of .train() and .eval(). This will therefore be visible by default with the next release.


ptrblck and Kaixhin, thank you very much! Very helpful.