You can use the following class as a starting point for customizing the `forward`
method of a torchvision model, e.g. ResNet-50:
class MyResnet50(models.resnet.ResNet):
    """ResNet-50 whose forward pass is written out so it can be edited.

    The constructor hands the base ``ResNet`` the same block type and
    per-stage block counts that torchvision's ``resnet50`` factory uses, so
    this module's submodules line up with the stock model and its
    ``state_dict`` can be loaded directly.
    """

    def __init__(self, pretrained=False):
        """Build the network, optionally copying in pretrained parameters.

        Args:
            pretrained: when true, load the parameters and buffers of
                ``models.resnet50(pretrained=True)`` into this instance.
        """
        # Bottleneck blocks, [3, 4, 6, 3] per stage: the ResNet-50 layout.
        # https://github.com/pytorch/vision/blob/e130c6cca88160b6bf7fea9b8bc251601a1a75c5/torchvision/models/resnet.py#L260
        stage_depths = [3, 4, 6, 3]
        super().__init__(models.resnet.Bottleneck, stage_depths)
        if not pretrained:
            return
        # Instantiate the stock pretrained model only to borrow its weights.
        reference = models.resnet50(pretrained=True)
        self.load_state_dict(reference.state_dict())

    def _forward_impl(self, x):
        """Run stem, residual stages and classifier head; return ``fc`` output.

        This mirrors torchvision's split between ``forward`` and
        ``_forward_impl`` (their note "[TorchScript super()]"); modify the
        steps below to customize the pass.
        """
        # Stem.
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # Residual stages, applied in order.
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        # Head: pool, keep dim 0 and flatten the remaining dims, classify.
        pooled = torch.flatten(self.avgpool(out), 1)
        logits = self.fc(pooled)
        return logits

    def forward(self, x):
        """Delegate to ``_forward_impl``, which holds the actual computation."""
        return self._forward_impl(x)
# Build the pretrained model and run a dummy batch through it for inference.
model = MyResnet50(pretrained=True)
# Switch to evaluation mode. In training mode BatchNorm normalises with the
# current batch's statistics and updates its running_mean/running_var
# buffers, so calling the model on random data would overwrite the
# pretrained statistics. eval() makes it use the stored ones instead.
model.eval()
x = torch.randn(2, 3, 224, 224)
# No gradients are needed here; skip building the autograd graph.
with torch.no_grad():
    output = model(x)