Now let’s say we want to use the output of “fc7” (RoIHeads->box_head->fc7) as input to another custom layer, for example a Linear 1024 -> 512. We also want to keep the pre-trained weights up to fc7 and train only the weights of the newly added layer for some loss function. Since the last module is wrapped in a class, is there any way of doing so?
I bring up the last sentence because, taking resnet152 as an example, I learned by printing the modules that it contains “9” layers and that the last layer is an FC. I dropped the last layer, built another Sequential network from the remaining ones, added a new layer on top, and trained the model using only the parameters of that new layer.
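For reference, the Sequential-truncation approach described above can be sketched roughly as follows. This is only an illustration: `pretrained` is a small hypothetical stand-in rather than the real resnet152, and the layer sizes are made up.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pretrained classifier whose last child is an FC layer.
pretrained = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),  # the "last FC" that gets dropped
)

# Keep every child except the last one and freeze those parameters.
feature_extractor = nn.Sequential(*list(pretrained.children())[:-1])
for p in feature_extractor.parameters():
    p.requires_grad = False

# New trainable layer stacked on top of the frozen features.
new_fc = nn.Linear(8, 4)
truncated_model = nn.Sequential(feature_extractor, new_fc)

out = truncated_model(torch.randn(2, 3, 16, 16))
```

With torchvision's ResNet the flatten step happens inside `forward` rather than as a child module, so an `nn.Flatten()` would have to be inserted before the new linear layer.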
For a quick experiment, I would register a forward hook on this particular layer, store the output activations, and reuse them in another model outside of this FasterRCNN model.
However, if you want to properly change the workflow, I would recommend deriving a custom model and adapting the forward methods as needed.
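A minimal sketch of the forward-hook idea could look like the following. `ToyBoxHead` is a hypothetical stand-in for the detector's box head; on the real model the hook would presumably be registered on something like `model.roi_heads.box_head.fc7`, following the module path given in the question.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the detector's box head (fc6 -> fc7).
class ToyBoxHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc6 = nn.Linear(32, 1024)
        self.fc7 = nn.Linear(1024, 1024)

    def forward(self, x):
        x = torch.relu(self.fc6(x))
        return torch.relu(self.fc7(x))

box_head = ToyBoxHead().eval()
activations = {}

def save_fc7(module, inputs, output):
    # Detach so the stored tensor carries no graph from the pretrained part.
    activations["fc7"] = output.detach()

handle = box_head.fc7.register_forward_hook(save_fc7)

with torch.no_grad():
    box_head(torch.randn(5, 32))
handle.remove()  # remove the hook once the activations are captured

# Separate trainable layer that consumes the stored activations.
extra_layer = nn.Linear(1024, 512)
features = extra_layer(activations["fc7"])
```

Note that in this toy example the hook sees the raw output of the `fc7` module, i.e. before the ReLU that `forward` applies afterwards; whether that matters depends on which activation you want to reuse.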
Hi, thanks for your reply.
Is there a code snippet available for doing so? For example, in the second case, when we derive a custom class and modify the layers, how can we assign the weights of the pre-trained model?
You could use this as the base code to modify your forward method, e.g. for resnet50:
import torch
from torchvision import models


class MyResnet50(models.resnet.ResNet):
    def __init__(self, pretrained=False):
        # Pass default resnet50 arguments to super init
        # https://github.com/pytorch/vision/blob/e130c6cca88160b6bf7fea9b8bc251601a1a75c5/torchvision/models/resnet.py#L260
        super(MyResnet50, self).__init__(models.resnet.Bottleneck, [3, 4, 6, 3])
        if pretrained:
            # Copy the pre-trained weights from the stock resnet50.
            # (Newer torchvision releases use a `weights=` argument instead of `pretrained=True`.)
            self.load_state_dict(models.resnet50(pretrained=True).state_dict())

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)


model = MyResnet50(pretrained=True)
x = torch.randn(2, 3, 224, 224)
output = model(x)
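To connect this back to the original question (reuse the pre-trained weights up to fc7 and train only a new Linear 1024 -> 512), here is a rough sketch of the same pattern with an extra layer. `Base` and `Extended` are hypothetical toy classes, not torchvision code, and the squared-output loss is just a placeholder. Since the derived model has parameters the pre-trained one lacks, the state dict is loaded with `strict=False`.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pretrained network ending in a 1024-d "fc7".
class Base(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc6 = nn.Linear(32, 1024)
        self.fc7 = nn.Linear(1024, 1024)

    def forward(self, x):
        x = torch.relu(self.fc6(x))
        return torch.relu(self.fc7(x))

class Extended(Base):
    def __init__(self):
        super().__init__()
        self.extra = nn.Linear(1024, 512)  # new, randomly initialised layer

    def forward(self, x):
        x = torch.relu(self.fc6(x))
        x = torch.relu(self.fc7(x))
        return self.extra(x)

pretrained = Base()  # pretend these weights were already trained
model = Extended()

# strict=False tolerates keys that exist only in the new model;
# result.missing_keys then lists the parameters of the new layer.
result = model.load_state_dict(pretrained.state_dict(), strict=False)

# Freeze everything, then re-enable gradients for the new layer only.
for p in model.parameters():
    p.requires_grad = False
for p in model.extra.parameters():
    p.requires_grad = True

optimizer = torch.optim.SGD(model.extra.parameters(), lr=0.1)
fc7_before = model.fc7.weight.clone()
extra_before = model.extra.weight.clone()

loss = model(torch.randn(4, 32)).pow(2).mean()  # placeholder loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

After the step, only `model.extra` has changed, while the copied `fc6`/`fc7` weights stay identical to the pre-trained ones. If the frozen part contains batch norm or dropout layers, it is probably also worth putting those submodules into `eval()` mode so their running statistics are not updated.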