Correct pre-processing for pretrained GoogLeNet

I would like to use torchvision.models.googlenet as a pre-trained backbone in a custom model. The inputs to my custom model (during training and evaluation) will already be cropped to 224 pixels. I need to understand how to apply the ImageNet mean and standard deviation values to my images before they are passed through the model.

Interestingly, the GoogLeNet class provides a _transform_input() method which sounds like it should handle all the pre-processing for you. When initializing with pretrained=True, the constructor automatically sets transform_input=True as well. However, in the docs example, a torchvision.transforms pipeline is used in addition to the model’s internal method. To me, this seems like we are performing the ImageNet transform on the input twice.

My questions are:

  1. Why does the GoogLeNet implementation include an internal transform method when none of the other torchvision models (e.g. densenet, mobilenet, etc.) have one?
  2. In the internal transform method, why are several 0.5 factors used in the formula that transforms the image? From my understanding, ImageNet transforms should be calculated as (pixel_value - mean) / std, but the current implementation does not follow this.
  3. Is applying a “double transform” like the docs example really the correct way to preprocess input for this pretrained model?

The script below tests several different pre-processing approaches:

  1. Using GoogLeNet._transform_input()
    a. No other additions
    b. Plus the docs torchvision.transforms pipeline
    c. Plus my manual implementation of ImageNet transforms
  2. Without using GoogLeNet._transform_input()
    a. No other additions
    b. Using the docs torchvision.transforms pipeline
    c. Using my manual implementation of ImageNet transforms

import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.models import googlenet
from torchvision import transforms


def load_tensor(filename):
    rgb = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2RGB)
    nchw = torch.tensor(rgb).permute(2, 0, 1).unsqueeze(0)
    normalized = nchw.float() / 255.0
    return normalized


def manual(filename):
    normalized = load_tensor(filename)
    return torch.stack((
        (normalized[:, 0, :, :] - 0.485) / 0.229,
        (normalized[:, 1, :, :] - 0.456) / 0.224,
        (normalized[:, 2, :, :] - 0.406) / 0.225), dim=1)


def torch_transforms(filename):
    input_image = Image.open(filename)
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)
    return input_batch


idx2class = { ... }  # Fill in from https://gist.github.com/maraoz/388eddec39d60c6d52d4

# Run model with different pre-processing approaches
model = googlenet(pretrained=True)
model.eval()
filename = '/home/addison/Documents/openalpr/vehicle-classifier/dog-224.jpg'
model.transform_input = True
scores = {
    '1a. Internal': model(load_tensor(filename)),
    '1b. Internal + torch.transforms': model(torch_transforms(filename)),
    '1c. Internal + manual': model(manual(filename))}
model.transform_input = False
scores.update({
    '2a. No internal': model(load_tensor(filename)),
    '2b. No internal (torch.transforms)': model(torch_transforms(filename)),
    '2c. No internal (manual)': model(manual(filename))})

# Summarize results
max_str = max(len(k) for k in scores)
for method, tensor in scores.items():
    probs = F.softmax(tensor, dim=1)
    values, indices = torch.topk(probs, k=1)
    print(f'{method:<{max_str}}: {idx2class[indices[0][0].item()]} ({values[0][0] * 100:.2f}%)')

I used the dog image from the docs example and pre-cropped it to 224 pixels in an external image editing program (see end of post). The approaches produce different confidences for the top class prediction:

1a. Internal                      : Samoyed, Samoyede (97.31%)
1b. Internal + torch.transforms   : Samoyed, Samoyede (91.27%)
1c. Internal + manual             : Samoyed, Samoyede (97.37%)
2a. No internal                   : Samoyed, Samoyede (96.38%)
2b. No internal (torch.transforms): Samoyed, Samoyede (33.95%)
2c. No internal (manual)          : Samoyed, Samoyede (83.95%)

dog-224

I encountered the same problem when evaluating the model on the ImageNet-1k dataset. It seems that a “double pre-process” is needed to achieve the accuracy claimed for torchvision.models.googlenet. I guess the way transform_input handles the input data may be related to the GoogLeNet model itself, and it is really confusing to me that so few people have mentioned this.

The internal transformation seems to revert the standard ImageNet-specific normalization and replace it with one that simply subtracts a constant “mean” and divides by a constant “std”. I am not sure what the final division by 0.5 stands for there. I found a commit where the contributor claimed that it was done to match the TensorFlow implementation. However, I don’t see this in the TF implementation.
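To make the “revert and replace” reading concrete, here is a minimal sketch in plain Python. It assumes the per-channel formula in `_transform_input` is `x * (std / 0.5) + (mean - 0.5) / 0.5`, which is how I read the torchvision source; the function names below are my own and not part of torchvision. Under that assumption, applying the internal transform to an ImageNet-normalized value gives the same result as `(pixel - 0.5) / 0.5` applied to the raw value in [0, 1].

```python
import math

# ImageNet statistics used by transforms.Normalize in the script above.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def imagenet_normalize(p, c):
    """Standard ImageNet normalization of a pixel value p in [0, 1], channel c."""
    return (p - MEAN[c]) / STD[c]


def internal_transform(z, c):
    """Assumed per-channel formula of GoogLeNet._transform_input (hypothetical helper)."""
    return z * (STD[c] / 0.5) + (MEAN[c] - 0.5) / 0.5


def half_normalize(p):
    """Map [0, 1] to [-1, 1] using a constant mean and std of 0.5."""
    return (p - 0.5) / 0.5


# The composition should collapse to half_normalize for every channel.
for c in range(3):
    for p in (0.0, 0.25, 0.5, 1.0):
        composed = internal_transform(imagenet_normalize(p, c), c)
        assert math.isclose(composed, half_normalize(p), abs_tol=1e-9)
```

If this reading is right, the 0.5 factors are just the mean and std of a [-1, 1] scaling, and with transform_input=True the model expects input that has already been ImageNet-normalized. That would be consistent with the docs example, where the pipeline and the internal method together amount to a single effective normalization rather than a doubled one.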

It is weird that 1b differs from 1c in your comparison. Probably the cropping…
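A rough sketch of why the cropping could matter: the torch_transforms pipeline in the question applies Resize(256) followed by CenterCrop(224), even though the image is already 224 pixels. Assuming the image is 224x224, and approximating the rounding torchvision uses, the helper below (a hypothetical name, not a torchvision function) computes which region of the original image survives that pipeline.

```python
def resize_then_center_crop(width, height, resize_to=256, crop=224):
    """Return the (left, top, right, bottom) box, in ORIGINAL pixel
    coordinates, that remains after Resize(resize_to) + CenterCrop(crop).

    Approximation: Resize scales the shorter side to `resize_to` and the
    crop is centered; torchvision's exact rounding may differ slightly.
    """
    scale = resize_to / min(width, height)
    new_w, new_h = round(width * scale), round(height * scale)
    left = (new_w - crop) / 2
    top = (new_h - crop) / 2
    return (left / scale, top / scale,
            (left + crop) / scale, (top + crop) / scale)


box = resize_then_center_crop(224, 224)
print(box)
```

For a 224x224 input, the surviving box covers roughly the central 196 pixels on each side. So 1b sees an upscaled image with about 14 pixels trimmed from every edge, while 1c sees the full image at its native resolution. That alone may account for the different confidences.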