class MyModel(nn.Module):
    """Stack a head module on top of a (pretrained) backbone module.

    The backbone's output is passed directly into ``last_layer``, so the
    head must accept whatever the backbone produces.

    Args:
        pretrained_model: Backbone module; its output becomes the head's input.
        last_layer: Head applied to the backbone's output. Defaults to
            ``nn.Identity()`` so the wrapper works out of the box; pass e.g.
            ``nn.Linear(in_features, num_classes)`` to attach a real head.
    """

    def __init__(self, pretrained_model, last_layer=None):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it the assignment below raises AttributeError and the
        # submodules' parameters are never registered.
        super().__init__()
        self.pretrained_model = pretrained_model
        self.last_layer = nn.Identity() if last_layer is None else last_layer

    def forward(self, x):
        """Run ``x`` through the backbone, then through the head."""
        return self.last_layer(self.pretrained_model(x))
# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights=`
# enum; IMAGENET1K_V1 is the checkpoint `pretrained=True` loads. Fall back to
# the legacy keyword on older torchvision releases that lack the enum.
pretrained_model = (
    torchvision.models.resnet18(
        weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
    )
    if hasattr(torchvision.models, "ResNet18_Weights")
    else torchvision.models.resnet18(pretrained=True)
)
model = MyModel(pretrained_model)
# 19 Likes  (forum-page residue kept as a comment so the file remains valid Python)