How can I replace the forward method of a predefined torchvision model with my customized forward function?

You could derive a custom class from torchvision's ResNet class and override its forward method:

import torchvision.models as models
from torchvision.models.resnet import ResNet, BasicBlock

class MyResNet18(ResNet):
    def __init__(self):
        super().__init__(BasicBlock, [2, 2, 2, 2])
        
    def forward(self, x):
        # customize the forward pass here; this example only applies the first conv layer
        x = self.conv1(x)
        return x


model = MyResNet18()
# if you need pretrained weights (torchvision >= 0.13 uses `weights`; older versions use pretrained=True)
model.load_state_dict(models.resnet18(weights=models.ResNet18_Weights.DEFAULT).state_dict())