How can I replace the forward method of a predefined torchvision model with my customized forward function?


I tried the following:

Each layer within the ResNet model has its own forward function, so you would need to change the forward method of each layer explicitly.

What kind of change do I have to make to the sublayers' forward functions if they are not supposed to change at all? Are the sublayers' forward functions overwritten by this action? What would the solution look like in code?

What do you want to do? Do you want to add some new layers to a pretrained model (resnet34) and then do as you wish in your forward()?
It would help much more if you could say what exactly you are after.

In this case I just want to remove self.avgpool, the flatten operation, and the fc layer, so that the output of the model is just the [bs x 512 x 7 x 7] feature maps (when 3 x 224 x 224 image tensors are fed in), as you can see in the image in the appendix. Beyond this case, I am interested in how to modify predefined torchvision models in general.

You could derive a custom class using the resnet class as its parent:

import torchvision.models as models
from torchvision.models.resnet import ResNet, BasicBlock

class MyResNet18(ResNet):
    def __init__(self):
        super(MyResNet18, self).__init__(BasicBlock, [2, 2, 2, 2])
        
    def forward(self, x):
        # change forward here
        x = self.conv1(x)
        return x


model = MyResNet18()
# if you need pretrained weights
# (newer torchvision releases use the weights= argument instead of pretrained=True)
model.load_state_dict(models.resnet18(pretrained=True).state_dict())

Aside from the solution kindly provided by @ptrblck, you can also do something like this:

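import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class MyCustomResnet18(nn.Module):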
    def __init__(self, pretrained=True):
        super().__init__()
        
        resnet18 = models.resnet18(pretrained=pretrained)
        # get all the modules (layers) before the fc layer at the end;
        # use children() here, since named_children() yields (name, module)
        # pairs, which nn.ModuleList does not accept
        self.features = nn.ModuleList(resnet18.children())[:-1]
        # nn.ModuleList doesn't implement forward(), so we can't call
        # self.features(images). nn.Sequential does, but it takes modules as
        # separate arguments rather than a list, so we unpack the items
        self.features = nn.Sequential(*self.features)
        # now lets add our new layers 
        in_features = resnet18.fc.in_features
        # from now, you can add any kind of layers in any quantity!  
        # Here I'm creating two new layers 
        self.fc0 = nn.Linear(in_features, 256)
        self.fc0_bn = nn.BatchNorm1d(256, eps = 1e-2)
        self.fc1 = nn.Linear(256, 256)
        self.fc1_bn = nn.BatchNorm1d(256, eps = 1e-2)
        
        # initialize all fc layers to xavier
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_normal_(m.weight, gain = 1)

    def forward(self, input_imgs):
        # in the forward pass you have full control;
        # first run the feature part of the pretrained model
        output = self.features(input_imgs)
        # since we are using fc layers from now on, we need to flatten the output.
        # the avgpool output has shape (batch, features, 1, 1), so we reshape it
        # to (batch, features). input_imgs.size(0) gives the batch size, and
        # -1 infers the rest
        output = output.view(input_imgs.size(0), -1)
        # then apply our new layers
        output = self.fc0_bn(F.relu(self.fc0(output)))
        output = self.fc1_bn(F.relu(self.fc1(output)))
                
        return output

You can get fancy and add new methods to your network (e.g. for freezing and unfreezing different parts of it, which can come in handy in fine-tuning).


So is it impossible to view the actual forward function of a pretrained torchvision model? I am facing a problem with that.
In a pretrained model, when I print the model, I can see all of its data members, i.e. the individual layers and their parameters. But the residual connections and their operations are defined in the forward function, which the printed description does not show. Is there any other way to find out where the residual connections are in the pretrained model?

You would have to check the source code, e.g. here for torchvision.models.

An alternative to @ptrblck 's solution is this:

import torchvision.models as models
import torch


def new_forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    return x


# define a resnet instance
resnet = models.resnet18()

# bind new_forward to this resnet instance and use it as its forward method
bound_method = new_forward.__get__(resnet, resnet.__class__)
setattr(resnet, 'forward', bound_method)

# you can also remove the 2 layers resnet.avgpool and resnet.fc because you are not using them in the new forward method
delattr(resnet, 'avgpool')
delattr(resnet, 'fc')

# call the new forward method
inputs = torch.rand(1, 3, 224, 224)
outputs = resnet(inputs)

print('type(resnet) = ', type(resnet))
print('type(resnet.forward) = ', type(resnet.forward))
print('outputs.shape = ', outputs.shape)