I just started learning Pytorch and I do not have good programming skills. I am trying to perform a segmentation task and given the limited amount of training samples, I have opted for transfer learning approach. I have managed to run the following model with my own data and the result is very bad.
class fcn(nn.Module):
def __init__(self, num_classes):
super(fcn, self).__init__()
self.stage1 = nn.Sequential(*list(pretrained_net.children())[:-4])
# change input channels to 9
self.stage1[0] = nn.Conv2d(9, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.stage2 = list(pretrained_net.children())[-4]
self.stage3 = list(pretrained_net.children())[-3]
self.scores1 = nn.Conv2d(512, num_classes, 1)
self.scores2 = nn.Conv2d(256, num_classes, 1)
self.scores3 = nn.Conv2d(128, num_classes, 1)
self.upsample_8x = nn.ConvTranspose2d(num_classes, num_classes, 16, 8, 4, bias=False)
self.upsample_4x = nn.ConvTranspose2d(num_classes, num_classes, 4, 2, 1, bias=False)
self.upsample_2x = nn.ConvTranspose2d(num_classes, num_classes, 4, 2, 1, bias=False)
def forward(self, x):
x = self.stage1(x)
s1 = x # 1/8
x = self.stage2(x)
s2 = x # 1/16
x = self.stage3(x)
s3 = x # 1/32
s3 = self.scores1(s3)
s3 = self.upsample_2x(s3)
s2 = self.scores2(s2)
s2 = s2 + s3
s1 = self.scores3(s1)
s2 = self.upsample_4x(s2)
s = s1 + s2
s = self.upsample_8x(s2)
return s
num_classes = 2 #len(classes)
pretrained_net = models.resnet18(pretrained=True)
model = fcn(num_classes)
With this same configuration, the model could only run with Resnet18 and Resnet32. When I tried other network structures such as Resnet50, VGG, Mobilenet. I got the following error
self.stage1[0] = nn.Conv2d(9, 64, kernel_size=7, stride=2, padding=3, bias=False)
File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\container.py", line 71, in __setitem__
key = self._get_item_by_idx(self._modules.keys(), idx)
File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\container.py", line 60, in _get_item_by_idx
raise IndexError('index {} is out of range'.format(idx))
IndexError: index 0 is out of range
Please, how can I solve this problem? How can I try a transfer learning framework exploring different network structures for a segmentation task? Any suggestions and comments would be highly appreciated.