I’m attempting to use the PyTorch built-in ResNet 50 model from https://pytorch.org/docs/stable/torchvision/models.html with single-channel (grayscale) images.
I figured out from various posts on this forum that I needed to change my model setup like so:
# MyResNet50
import torchvision
import torch.nn as nn

def buildResNet50Model(numClasses):
    # get the stock PyTorch ResNet50 model w/ pretrained set to True
    model = torchvision.models.resnet50(pretrained = True)

    # freeze all model parameters so we don’t backprop through them during training (except the FC layer that will be replaced)
    for param in model.parameters():
        param.requires_grad = False
    # end for

    # !!!!!! this line is specific to the 3 channel to one channel change, other lines in this function are the same as before !!!!!!
    # change 1st conv layer from 3 channel to 1 channel
    model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

    # the last (fully connected) layer per the number of classes
    # first, get/save the current number of input features to the fc layer
    numFcInputs = model.fc.in_features
    # now replace the fc layer with minor changes, and using our number of classes
    model.fc = nn.Sequential(nn.Linear(numFcInputs, 256),
                             nn.ReLU(),
                             nn.Dropout(0.2),
                             nn.Linear(256, numClasses))

    return model
# end function
Then when I went to train again I got this error:
RuntimeError: output with shape [1, 224, 224] doesn’t match the broadcast shape [3, 224, 224]
I figured out from this Stack Overflow post ("RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]") that I need to change the transforms.Normalize part of the transform from 3 numbers to 1. Based on various PyTorch examples, I’m using this for a transform:
from torchvision import transforms

TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
So it seems I need to change the transforms.Normalize line to:
transforms.Normalize([x], [y])
where x and y are floating-point numbers, at which point the network then trains successfully.
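To make the shape mismatch concrete, here is a minimal sketch of what I believe is going on: the per-channel values are reshaped to [C, 1, 1] and applied in place, and an in-place operation cannot broadcast a 1-channel tensor up to 3 channels. The helper `normalize_in_place` is a hypothetical stand-in I wrote for illustration, not the torchvision source, and the random tensor stands in for a real image.

```python
import torch

def normalize_in_place(img, mean, std):
    """Hypothetical stand-in for transforms.Normalize on a [C, H, W] tensor."""
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)
    # In-place ops keep img's shape, so broadcasting may not add channels.
    return img.sub_(mean_t).div_(std_t)

gray = torch.rand(1, 224, 224)  # single-channel image, values in [0, 1)

# Three means/stds against one channel: expected to raise RuntimeError.
try:
    normalize_in_place(gray.clone(), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    three_value_error = None
except RuntimeError as err:
    three_value_error = str(err)

# One mean and one std match the single channel, so this succeeds.
out = normalize_in_place(gray.clone(), [0.5], [0.5])
print(three_value_error)
print(out.shape)
```

This only shows why a single x and a single y make the shapes line up; it says nothing about which values are best.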
My question is, how do I determine x and y? Here are some possibilities I could think of:
# use 1st number
transforms.Normalize([0.485], [0.229])
# use middle number
transforms.Normalize([0.456], [0.224])
# use average
transforms.Normalize([0.449], [0.226])
# use nice round numbers
transforms.Normalize([0.5], [0.25])
# use 0.5 for both
transforms.Normalize([0.5], [0.5])
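For the "use average" option, the values can be checked by averaging the three per-channel numbers directly. A small sketch (the variable names are mine):

```python
# Per-channel mean and std from the 3-channel transform
rgb_mean = [0.485, 0.456, 0.406]
rgb_std = [0.229, 0.224, 0.225]

# Plain average across the three channels
avg_mean = sum(rgb_mean) / len(rgb_mean)
avg_std = sum(rgb_std) / len(rgb_std)

print(round(avg_mean, 3), round(avg_std, 3))  # prints 0.449 0.226
```

Averaging the stds this way is only a rough guess, since the std of a combined grayscale channel also depends on how the color channels correlate.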
Since the origin of the original six numbers, [0.485, 0.456, 0.406] and [0.229, 0.224, 0.225], is not especially clear (see "Origin of the means and stds used for preprocessing?", Issue #1439 in pytorch/vision on GitHub), I’m not sure how to go about determining which of the above (or something entirely different) is best, other than by trial and error. Any suggestions?
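One "something entirely different" option I can think of is to measure x and y from my own grayscale training images instead of deriving them from the three-channel values. A minimal sketch, assuming the images are already [1, H, W] tensors scaled to [0, 1] (as ToTensor produces); the function name `grayscale_mean_std` and the random stand-in images are hypothetical:

```python
import torch

def grayscale_mean_std(images):
    """Return (mean, std) over all pixels of an iterable of [1, H, W] tensors."""
    pixel_sum = 0.0
    pixel_sq_sum = 0.0
    pixel_count = 0
    for img in images:
        img = img.double()  # accumulate in float64 to limit rounding error
        pixel_sum += img.sum().item()
        pixel_sq_sum += (img ** 2).sum().item()
        pixel_count += img.numel()
    mean = pixel_sum / pixel_count
    # population variance: E[x^2] - (E[x])^2
    var = pixel_sq_sum / pixel_count - mean ** 2
    return mean, max(var, 0.0) ** 0.5

# Stand-in data: random tensors in place of a real training set
torch.manual_seed(0)
fake_images = [torch.rand(1, 224, 224) for _ in range(8)]

x, y = grayscale_mean_std(fake_images)
print(x, y)  # candidates for transforms.Normalize([x], [y])
```

Whether dataset statistics actually work better than the values above with the frozen pretrained layers is something I would still have to test.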