How can I replace the forward method of a predefined torchvision model with my customized forward function?

Aside from the solution kindly provided by @ptrblck, you can also do something like this:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class MyCustomResnet18(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()

        # note: torchvision >= 0.13 deprecates `pretrained` in favor of
        # `weights=` (e.g. weights=models.ResNet18_Weights.DEFAULT)
        resnet18 = models.resnet18(pretrained=pretrained)
        # here we get all the modules (layers) before the fc layer at the end.
        # note that named_children() yields (name, module) tuples, so passing it
        # to nn.ModuleList instead of children() fails with an error
        self.features = nn.ModuleList(resnet18.children())[:-1]
        # Now we have our layers up to the fc layer, but we are not finished yet:
        # nn.ModuleList doesn't implement forward(), so you can't call
        # self.features(images). Therefore we wrap the layers in nn.Sequential,
        # and since nn.Sequential doesn't accept a list, we unpack the items
        self.features = nn.Sequential(*self.features)
        # now let's add our new layers
        in_features = resnet18.fc.in_features
        # from here on, you can add any kind of layers in any quantity!
        # Here I'm creating two new layers
        self.fc0 = nn.Linear(in_features, 256)
        self.fc0_bn = nn.BatchNorm1d(256, eps=1e-2)
        self.fc1 = nn.Linear(256, 256)
        self.fc1_bn = nn.BatchNorm1d(256, eps=1e-2)

        # initialize all fc layers with Xavier (Glorot) normal init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_normal_(m.weight, gain=1)

    def forward(self, input_imgs):
        # now in the forward pass you have full control;
        # we can use the feature part of our pretrained model like this
        output = self.features(input_imgs)
        # since we use fc layers from now on, we need to flatten the output.
        # the avgpool layer (still part of self.features) gives the shape
        # (batch, features, 1, 1), which we reshape to (batch, features).
        # input_imgs.size(0) gives the batch size and -1 infers the rest
        output = output.view(input_imgs.size(0), -1)
        # and then our new layers
        output = self.fc0_bn(F.relu(self.fc0(output)))
        output = self.fc1_bn(F.relu(self.fc1(output)))

        return output
```

You can get fancy and add new methods to your network (e.g. for freezing and unfreezing different parts of it), which can come in handy when fine-tuning.
