Changing the built-in ResNet50 model to 1-channel images - how to set transforms.Normalize([...])?

I’m attempting to use the built-in PyTorch ResNet50 model from https://pytorch.org/docs/stable/torchvision/models.html with single-channel (grayscale) images.

I figured out from various posts on this forum that I needed to change my model setup like so:

# MyResNet50

import torchvision
import torch.nn as nn

def buildResNet50Model(numClasses):
    # get the stock PyTorch ResNet50 model with pretrained weights
    model = torchvision.models.resnet50(pretrained=True)

    # freeze all existing parameters so we don’t backprop through them during training
    # (the conv1 and fc layers replaced below are created after this loop, so they remain trainable)
    for param in model.parameters():
        param.requires_grad = False
    # end for

    # !!!!!! this line is specific to the 3 channel to 1 channel change, other lines in this function are the same as before !!!!!!
    # change the 1st conv layer from 3 channels to 1 channel
    # note: this new layer is randomly initialized, so the pretrained conv1 weights are lost
    model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

    # replace the last (fully connected) layer per the number of classes
    # first, get/save the current number of input features to the fc layer
    numFcInputs = model.fc.in_features
    # now replace the fc layer with minor changes, and using our number of classes
    model.fc = nn.Sequential(nn.Linear(numFcInputs, 256),
                             nn.ReLU(),
                             nn.Dropout(0.2),
                             nn.Linear(256, numClasses))

    return model
# end function
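One caveat I noticed with this replacement: the new conv1 starts from random weights, so the pretrained first-layer filters are thrown away. A trick I’ve seen mentioned (not something torchvision does for you automatically) is to sum the pretrained conv1 weights across the RGB dimension and copy the result into the new single-channel layer; a gray input then produces the same response the pretrained layer would give to an image with R == G == B. A minimal sketch:

import torchvision
import torch.nn as nn

model = torchvision.models.resnet50(pretrained=True)

# save the pretrained 3-channel conv1 weights, shape [64, 3, 7, 7]
pretrainedConv1Weights = model.conv1.weight.data.clone()

# build the replacement 1-channel layer
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

# sum across the RGB dimension -> shape [64, 1, 7, 7]
model.conv1.weight.data = pretrainedConv1Weights.sum(dim=1, keepdim=True)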

Then when I went to train again I got this error:

RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]

I figured out from this Stack Overflow post, “python - RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]”, that I need to change the transforms.Normalize part of the transform from 3 numbers to 1. Based on various PyTorch examples I’m using this for a transform:

from torchvision import transforms

TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop(size=224),  # note: a no-op after resizing to exactly 224 x 224
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

so it seems I need to change the transforms.Normalize line to:

transforms.Normalize([x], [y])

where x and y are floating-point numbers; with any single pair of values in place, the network then trains successfully.
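For context, here is what the transform would then look like end to end (using [0.5], [0.5] purely as a placeholder until I know the right values; the transforms.Grayscale step would only be needed if the source images are stored as RGB files):

from torchvision import transforms

# single-channel version of the transform; [0.5], [0.5] is just a placeholder
TRANSFORM = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # only needed if source images are RGB
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
    ])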

My question is, how do I determine x and y? Here are some possibilities I could think of:

# use the 1st mean and std
transforms.Normalize([0.485], [0.229])
# use the middle mean and std
transforms.Normalize([0.456], [0.224])
# use the averages: (0.485 + 0.456 + 0.406) / 3 and (0.229 + 0.224 + 0.225) / 3
transforms.Normalize([0.449], [0.226])
# use nice round numbers
transforms.Normalize([0.5], [0.25])
# use 0.5 for both
transforms.Normalize([0.5], [0.5])

Since the origin of the original six numbers [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] is not especially clear (see Origin of the means and stds used for preprocessing? · Issue #1439 · pytorch/vision · GitHub), I’m not sure how to determine which of the above (or something entirely different) is best, other than by trial and error. Any suggestions?

You could calculate the stats from your current training dataset, the same way the posted values were computed from the ImageNet data (or just use the 0.5 “defaults”).
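For example, a minimal sketch (assuming the dataset yields (image, label) pairs with single-channel image tensors; the function name and batch size are placeholders):

import torch
from torch.utils.data import DataLoader

def computeMeanAndStd(dataset):
    # accumulate sums over every pixel in the (single-channel) dataset
    loader = DataLoader(dataset, batch_size=64, shuffle=False)
    pixelSum = 0.0
    pixelSquaredSum = 0.0
    numPixels = 0
    for images, _ in loader:
        # images shape: [N, 1, H, W]
        pixelSum += images.sum().item()
        pixelSquaredSum += (images ** 2).sum().item()
        numPixels += images.numel()
    mean = pixelSum / numPixels
    # std via sqrt(E[x^2] - E[x]^2)
    std = (pixelSquaredSum / numPixels - mean ** 2) ** 0.5
    return mean, std

Then plug the results into transforms.Normalize([mean], [std]).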

Thanks for the reply; it looks like I should follow these posts to calculate the stats: