How can I modify a ResNet or VGG network to use grayscale images? I am loading the network the following way:
```python
m = torchvision.models.segmentation.fcn_resnet50(pretrained=False, progress=True, num_classes=2, aux_loss=None)
```
Is there some way I can tweak this model after loading it?
Thanks
You can always define a custom ResNet and change its first layer to match your input shape. For example, with resnet18:
```python
import torch
import torch.nn as nn
import torchvision

resnet = torchvision.models.resnet18()
input = torch.randn((16, 3, 244, 244))
output = resnet(input)
print(output.shape)

# this fails because resnet expects 3 input channels
# input = torch.randn((16, 1, 244, 244))
# output = resnet(input)
# print(output.shape)


class MyResNet(nn.Module):
    def __init__(self, in_channels=1):
        super(MyResNet, self).__init__()
        # bring resnet
        self.model = torchvision.models.resnet18()
        # original definition of the first layer in the ResNet class
        # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # your case
        self.model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

    def forward(self, x):
        return self.model(x)


my_resnet = MyResNet()
# Now this works just fine!
input = torch.randn((16, 1, 244, 244))
output = my_resnet(input)
print(output.shape)
```
Now you can try the same for your model. If you want to dive into how these models are built, I suggest looking at https://github.com/pytorch/vision/tree/master/torchvision/models. You can learn a lot about building networks just by reading through the different models.
I hope this helps!