Suppose that a model consists of multiple components, e.g. an object recognition model. A pseudo-model might look as follows:
```python
import torch

class FooModel(torch.nn.Module):
    def __init__(self):
        super(FooModel, self).__init__()
        # feature extraction
        self.backbone = ResNet101()
        # region proposal
        self.rpn = RegionProposalNetwork()
        # object detection
        self.localization = LocalizationModel()
        # object classification
        self.classifier = Classifier()

    def forward(self, inputs, targets):
        feats = self.backbone(inputs)
        rpn_loss_cls, rpn_loss_box, proposals = self.rpn(feats)
        roi_loss, box_feats = self.localization(feats, proposals)
        cls_loss = self.classifier(box_feats, targets)
        return rpn_loss_cls, rpn_loss_box, roi_loss, cls_loss
```
Now I use `param.requires_grad = False` to freeze some layers in order to fine-tune the remaining parameters. In this case, gradients are no longer calculated for the frozen layers, but the losses are still computed. A dummy fine-tuning strategy is defined below:
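For context, the `freeze_for_tune()` / `unfreeze_me()` methods used below are not part of PyTorch; they are assumed to be custom helpers that toggle `requires_grad` on a submodule's parameters. A minimal sketch of such a helper (the name `set_requires_grad` is my own) could look like this:

```python
import torch

def set_requires_grad(module: torch.nn.Module, flag: bool) -> None:
    """Freeze (flag=False) or unfreeze (flag=True) every parameter of a submodule."""
    for param in module.parameters():
        param.requires_grad = flag

# tiny demo on a stand-in submodule
layer = torch.nn.Linear(4, 2)
set_requires_grad(layer, False)  # frozen: autograd will not compute gradients for it
assert all(not p.requires_grad for p in layer.parameters())
set_requires_grad(layer, True)   # unfrozen again
```

With a helper like this, `freeze_for_tune()` would simply call it with `False` on `self`, and `unfreeze_me()` with `True`.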
```python
def FineTune(model, data_loader):
    for epoch in range(max_epoch):
        if epoch < fine_tune_backbone:
            # freeze the other submodels
            model.rpn.freeze_for_tune()
            model.localization.freeze_for_tune()
            model.classifier.freeze_for_tune()
            # unfreeze the backbone for fine-tuning
            model.backbone.unfreeze_me()
        elif fine_tune_backbone <= epoch < fine_tune_rpn:
            model.rpn.unfreeze_me()
        elif fine_tune_rpn <= epoch < fine_tune_roi:
            model.localization.unfreeze_me()
        elif fine_tune_roi <= epoch < fine_tune_cls:
            model.backbone.freeze_for_tune()
            model.rpn.freeze_for_tune()
            model.localization.freeze_for_tune()
            model.classifier.unfreeze_me()
        else:
            # train everything together, without the backbone
            model.rpn.unfreeze_me()
            model.localization.unfreeze_me()

        inputs, targets = data_loader.get_batch()
        rpn_loss_cls, rpn_loss_box, roi_loss, cls_loss = model(inputs, targets)
        ...
```
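One pattern that is common in PyTorch (and may answer part of the question) is to pass only the trainable parameters to the optimizer after freezing, rather than relying on the epoch logic alone. A small sketch with stand-in `Linear` layers (the two-layer `Sequential` here is purely illustrative):

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),  # stand-in for the frozen backbone
    torch.nn.Linear(8, 2),  # stand-in for the head being fine-tuned
)

# freeze the "backbone"
for p in model[0].parameters():
    p.requires_grad = False

# hand the optimizer only the parameters that still require gradients
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=1e-3)

assert len(trainable) == 2  # weight and bias of the second Linear only
```

If the set of trainable parameters changes at some epoch, the optimizer can be rebuilt (or extended via `optimizer.add_param_group`) at that point.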
The question here is: how do I fine-tune each part of this model? Am I doing it correctly, or is there a canonical approach that PyTorch provides? Papers usually only state something like "we fine-tuned the xxx model after the y-th epoch." What does a fine-tuning implementation typically look like? And if some components are fixed, should I also switch those submodels to inference mode (i.e., evaluation)? Thanks for any input.
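On the last point about inference mode: setting `requires_grad = False` does not stop layers such as `BatchNorm` from updating their running statistics, or `Dropout` from being applied, so a frozen submodule is often additionally put into eval mode. A minimal sketch (the `Sequential` below is just a stand-in for a frozen submodel):

```python
import torch

frozen = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
)

# stop gradient computation for the frozen part
for p in frozen.parameters():
    p.requires_grad = False

# additionally stop BatchNorm running-stat updates (and Dropout, if present)
frozen.eval()

assert not frozen.training
```

Whether this is appropriate depends on the model: if the frozen submodule contains no `BatchNorm`/`Dropout`-style layers, freezing the parameters alone is enough.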