Extracting feature vector for grey images via ResNet18: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]

I have 800x600 images with only one channel. I am trying to use a pre-trained ResNet18 to extract their feature vectors, but the model expects 3-channel input:

import torch
import torchvision
import torchvision.models as models
from PIL import Image

img = Image.open("labeled-data/train_moth/moth/frame163.png")


# Load the pretrained model
model = models.resnet18(pretrained=True)

# Use the model object to select the desired layer
layer = model._modules.get('avgpool')

# Set model to evaluation mode
model.eval()

transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def get_vector(image):
    # Create a PyTorch tensor with the transformed image
    t_img = transforms(image)
    t_img = torch.cat((t_img, t_img, t_img), 0)     # replicate the single channel to 3
    # Create a vector of zeros that will hold our feature vector
    # The 'avgpool' layer has an output size of 512
    my_embedding = torch.zeros(512)

    # Define a function that will copy the output of a layer
    def copy_data(m, i, o):
        my_embedding.copy_(o.flatten())                 # <-- flatten

    # Attach that function to our selected layer
    h = layer.register_forward_hook(copy_data)
    # Run the model on our transformed image
    with torch.no_grad():                               # <-- no_grad context
        model(t_img.unsqueeze(0))                       # <-- unsqueeze
    # Detach our copy function from the layer
    h.remove()
    # Return the feature vector
    return my_embedding

Here’s the error I am getting:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-59ab45f8c1e6> in <module>
     42 
     43 
---> 44 pic_vector = get_vector(img)

<ipython-input-5-59ab45f8c1e6> in get_vector(image)
     21 def get_vector(image):
     22     # Create a PyTorch tensor with the transformed image
---> 23     t_img = transforms(image)
     24     t_img = torch.cat((t_img, t_img, t_img), 0)
     25     # Create a vector of zeros that will hold our feature vector

~/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, img)
     59     def __call__(self, img):
     60         for t in self.transforms:
---> 61             img = t(img)
     62         return img
     63 

~/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, tensor)
    210             Tensor: Normalized Tensor image.
    211         """
--> 212         return F.normalize(tensor, self.mean, self.std, self.inplace)
    213 
    214     def __repr__(self):

~/anaconda3/lib/python3.7/site-packages/torchvision/transforms/functional.py in normalize(tensor, mean, std, inplace)
    296     if std.ndim == 1:
    297         std = std[:, None, None]
--> 298     tensor.sub_(mean).div_(std)
    299     return tensor
    300 

RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]



Code is from: https://stackoverflow.com/a/63552285/2414957

I thought using

t_img = torch.cat((t_img, t_img, t_img), 0)

would fix it, but it didn't.
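
Looking at the traceback, the error is raised by transforms(image) itself, on the Normalize step, before my torch.cat line ever runs. It reproduces even without the image (the random tensor below is just a stand-in for the transformed grayscale frame):

import torch
import torchvision

# stand-in for the image after Resize/CenterCrop/ToTensor: shape [1, 224, 224]
gray = torch.rand(1, 224, 224)

norm = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
norm(gray)  # RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]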

Here’s a bit about the image:

$ identify frame163.png 
frame163.png PNG 800x600 800x600+0+0 8-bit Gray 256c 175297B 0.000u 0:00.000

This part of the error suggests the problem is in your use of torchvision.transforms.Normalize.

--> 212 return F.normalize(tensor, self.mean, self.std, self.inplace)

The docs for that transform are here: https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Normalize.

That page describes its arguments as:

  • mean (sequence) – Sequence of means for each channel.
  • std (sequence) – Sequence of standard deviations for each channel.

In your code, you gave Normalize 3 means and 3 standard deviations, which it applies per channel: internally it reshapes them to [3, 1, 1] and computes (tensor - mean) / std, and a [3, 1, 1] mean cannot broadcast against your [1, 224, 224] tensor. That is exactly the sub_/div_ line in the traceback:
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
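
You can see the broadcast failure in isolation; this is the tensor.sub_(mean) line from the traceback, with a random tensor standing in for your transformed image:

import torch

mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]  # reshaped to [3, 1, 1], as in F.normalize
torch.rand(1, 224, 224).sub_(mean)
# RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]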

Since your inputs have only 1 channel, you should pass a single mean and a single standard deviation, like this:
torchvision.transforms.Normalize(mean=[0.485], std=[0.229])
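
With that change the Compose succeeds, and the torch.cat((t_img, t_img, t_img), 0) line you already have replicates the normalized channel to the 3 channels ResNet18 expects. Here is a minimal sketch of the adjusted pipeline (reusing the red-channel ImageNet statistics for the single grey channel is a common heuristic, not an official grayscale value):

import torch
import torchvision
from PIL import Image

transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),  # 8-bit Gray PNG -> [1, 224, 224]
    torchvision.transforms.Normalize(mean=[0.485], std=[0.229]),
])

img = Image.open("labeled-data/train_moth/moth/frame163.png")
t_img = transforms(img)                      # [1, 224, 224], normalized
t_img = torch.cat((t_img, t_img, t_img), 0)  # [3, 224, 224], ready for ResNet18

The rest of get_vector (the avgpool hook, unsqueeze, no_grad) then runs unchanged and returns the 512-dimensional feature vector.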