Hello,
suppose that a model consists of multiple components, e.g., an object detection model. A pseudo-model might look as follows:
import torch

class FooModel(torch.nn.Module):
    def __init__(self):
        super(FooModel, self).__init__()
        # feature extraction
        self.backbone = ResNet101()
        # region proposal
        self.rpn = RegionProposalNetwork()
        # object localization
        self.localization = LocalizationModel()
        # object classification
        self.classifier = Classifier()

    def forward(self, inputs, targets):
        feats = self.backbone(inputs)
        rpn_loss_cls, rpn_loss_box, proposals = self.rpn(feats)
        roi_loss, box_feats = self.localization(feats, proposals)
        cls_loss = self.classifier(box_feats, targets)
        return rpn_loss_cls, rpn_loss_box, roi_loss, cls_loss
Now I set param.requires_grad = False to freeze some layers in order to fine-tune the remaining parameters. In this case, gradients are no longer computed for the frozen layers; however, the losses are still calculated.
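The freeze_for_tune() and unfreeze_me() calls below are my own helper names, not PyTorch API; a minimal sketch of what I mean, assuming each submodule inherits from this mixin alongside torch.nn.Module:

class FreezeMixin:
    # Hypothetical helpers (my own names, not part of PyTorch):
    # they simply toggle requires_grad on all parameters of the submodule.
    def freeze_for_tune(self):
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze_me(self):
        for param in self.parameters():
            param.requires_grad = True

With these helpers, a dummy fine-tuning strategy is defined below: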
def FineTune(model, data_loader):
    for epoch in range(max_epoch):
        if epoch < fine_tune_backbone:
            # freeze all other components
            model.rpn.freeze_for_tune()
            model.localization.freeze_for_tune()
            model.classifier.freeze_for_tune()
            # unfreeze the backbone for fine-tuning
            model.backbone.unfreeze_me()
        elif epoch < fine_tune_rpn:
            model.rpn.unfreeze_me()
        elif epoch < fine_tune_roi:
            model.localization.unfreeze_me()
        elif epoch < fine_tune_cls:
            model.backbone.freeze_for_tune()
            model.rpn.freeze_for_tune()
            model.localization.freeze_for_tune()
            model.classifier.unfreeze_me()
        else:
            # train everything together, except the backbone
            model.rpn.unfreeze_me()
            model.localization.unfreeze_me()

        inputs, targets = data_loader.get_batch()
        rpn_loss_cls, rpn_loss_box, roi_loss, cls_loss = model(inputs, targets)
        ...
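In case it matters, my assumption is that the optimizer should then only see the currently trainable parameters, e.g. rebuilt after each (un)freezing step (the hyperparameters here are just placeholders):

# Rebuild the optimizer over the trainable parameters only,
# so that frozen parameters are not registered with it at all.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.001, momentum=0.9)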
The question here is: how do I fine-tune each part of this model? Am I doing it correctly, or is there a canonical approach that PyTorch provides? If you look at papers, they usually just write something like "we fine-tuned the xxx model after the y-th epoch". What does the implementation of fine-tuning look like in general? And if some components are fixed, should I also switch those submodels to inference mode (i.e., evaluation)?
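By inference mode I mean something like this (just to illustrate the question, with the backbone as the frozen part):

# Put the frozen backbone in evaluation mode, so that BatchNorm uses its
# running statistics and Dropout is disabled; requires_grad = False
# alone only stops gradient computation and does not change this.
model.backbone.eval()

Thanks for any inputs.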