Hello,
I am working with multi-band input (remote-sensing imagery) and I want to use a pretrained model.
I modified the first convolution of my model to handle 7-channel images.
Everything works fine, except that, for an unknown reason, the output dimensions change from 38x38 to 19x19.
When I run a plain fcn_resnet50 on my RGB images, the original dimensions are preserved. So I do not understand why the size changes now, since I kept the same conv parameters (stride and padding).
Here is the code for my model:
class TransferLearningModel(nn.Module):
def __init__(self, args):
super(TransferLearningModel, self).__init__()
self = self.float()
# creating first layer
self.convRGB = nn.Sequential(nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
nn.BatchNorm2d(64,eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True))
self.convbands = nn.Sequential(nn.Conv2d(args.nb_channels-3, 64,kernel_size=(7,7), stride=(2,2), padding=(3, 3), bias=False),
nn.BatchNorm2d(64,eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True))
self.main = models.segmentation.fcn_resnet50(True)
self.main.requires_grad = False
self.convRGB[0].weight.data = torch.clone(self.main.backbone.conv1.weight.data)
self.convRGB[1].weight.data = torch.clone(self.main.backbone.bn1.weight.data)
# emptying the first resnet conv
self.main.backbone.conv1 = torch.nn.Identity()
self.main.backbone.bn1 = torch.nn.Identity()
self.main.backbone.relu = torch.nn.Identity()
self.convRGB.requires_grad = False
self.main.classifier[4] = nn.Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
self.main.classifier[4].requires_grad = True
#weight initialization
self.convbands[0].apply(self.init_weights)
self.main.classifier[4].apply(self.init_weights)
self.opti = optim.Adam(self.parameters())
def init_weights(self,layer):
nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
def forward(self, input):
# passing through our layers
x1bands = self.convbands(input[:,3:,:,:])
x1RGB = self.convRGB(input[:,0:3,:,:])
x2 = self.main(x1bands+x1RGB)
out = x2["out"]
return out