Grayscale images for resnet and deeplabv3

How can I modify a ResNet or VGG network to use grayscale images? I am loading the network the following way:

m=torchvision.models.segmentation.fcn_resnet50(pretrained=False, progress=True, num_classes=2, aux_loss=None)

Is there some way I can tweak this model after loading it?

Thanks

You can always define a custom resnet and change the first layer to adapt for your input shape. For example, with resnet18:

import torch 
import torchvision

resnet = torchvision.models.resnet18()

input = torch.randn((16,3,244,244))
output = resnet(input)
print(output.shape)

# this fails because resnet expects 3 input channels

#input = torch.randn((16,1,244,244))
#output = resnet(input)
#print(output.shape)

import torch.nn as nn

class MyResNet(nn.Module):

    def __init__(self, in_channels=1):
        super(MyResNet, self).__init__()

        # start from a stock resnet18
        self.model = torchvision.models.resnet18()

        # original definition of the first layer in the ResNet class:
        # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # replacement that accepts `in_channels` input channels (1 for grayscale)
        self.model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

    def forward(self, x):
        return self.model(x)

my_resnet = MyResNet()

# Now this works just fine!

input = torch.randn((16,1,244,244))
output = my_resnet(input)
print(output.shape)

Now you could try the same for your model. If you want to dive into how these models are built, I suggest you go to https://github.com/pytorch/vision/tree/master/torchvision/models. You can learn a lot about how to build networks just by going over the different models.

I hope this helps!
