I have 600x800 images that have only 1 channel. I am trying to use pre-trained ResNet18 to extract their feature vectors, however the code expects 3 channel:
import torch
import torchvision
import torchvision.models as models
from PIL import Image
img = Image.open("labeled-data/train_moth/moth/frame163.png")
# Load the pretrained model
model = models.resnet18(pretrained=True)
# Use the model object to select the desired layer
layer = model._modules.get('avgpool')
# Set model to evaluation mode
model.eval()
transforms = torchvision.transforms.Compose([
torchvision.transforms.Resize(256),
torchvision.transforms.CenterCrop(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def get_vector(image):
# Create a PyTorch tensor with the transformed image
t_img = transforms(image)
t_img = torch.cat((t_img, t_img, t_img), 0)
# Create a vector of zeros that will hold our feature vector
# The 'avgpool' layer has an output size of 512
my_embedding = torch.zeros(512)
# Define a function that will copy the output of a layer
def copy_data(m, i, o):
my_embedding.copy_(o.flatten()) # <-- flatten
# Attach that function to our selected layer
h = layer.register_forward_hook(copy_data)
# Run the model on our transformed image
with torch.no_grad(): # <-- no_grad context
model(t_img.unsqueeze(0)) # <-- unsqueeze
# Detach our copy function from the layer
h.remove()
# Return the feature vector
return my_embedding
Here’s the error I am getting:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-5-59ab45f8c1e6> in <module>
42
43
---> 44 pic_vector = get_vector(img)
<ipython-input-5-59ab45f8c1e6> in get_vector(image)
21 def get_vector(image):
22 # Create a PyTorch tensor with the transformed image
---> 23 t_img = transforms(image)
24 t_img = torch.cat((t_img, t_img, t_img), 0)
25 # Create a vector of zeros that will hold our feature vector
~/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, img)
59 def __call__(self, img):
60 for t in self.transforms:
---> 61 img = t(img)
62 return img
63
~/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, tensor)
210 Tensor: Normalized Tensor image.
211 """
--> 212 return F.normalize(tensor, self.mean, self.std, self.inplace)
213
214 def __repr__(self):
~/anaconda3/lib/python3.7/site-packages/torchvision/transforms/functional.py in normalize(tensor, mean, std, inplace)
296 if std.ndim == 1:
297 std = std[:, None, None]
--> 298 tensor.sub_(mean).div_(std)
299 return tensor
300
RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
pic_vector = get_vector(img)
Code is from: https://stackoverflow.com/a/63552285/2414957
I thought using
t_img = torch.cat((t_img, t_img, t_img), 0)
would be helpful but I was wrong.
Here’s a bit about image:
$ identify frame163.png
frame163.png PNG 800x600 800x600+0+0 8-bit Gray 256c 175297B 0.000u 0:00.000