This is Aman Goyal. I am currently doing research at CMU and MSU.
I have a pre-trained model file for a deep-learning-based tracking method. The model has three modules: backbone, neck and head.
All I need is its backbone weights, which come from a ResNet-50 trained on the BDD100K detection dataset.
Is it possible to extract only the backbone weights from the .pth file, which contains all the modules trained end to end?
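The .pth file contains an ordered dictionary (an `OrderedDict`). Its keys are not sorted alphabetically; they appear in the order in which the submodules were registered when the model was instantiated.
This means that if you call `model.backbone.state_dict().keys()`, you can see which entries correspond to the backbone. If the backbone is the model's `backbone` attribute, the same entries appear in the full model's state dict with a `backbone.` prefix (for example, `conv1.weight` becomes `backbone.conv1.weight`).
You can then pick out just those entries from the .pth file.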
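As a minimal sketch of this key-matching idea, the hypothetical helper `select_backbone_entries` below takes the keys reported by `model.backbone.state_dict().keys()` (which have no prefix) and looks up the corresponding entries in the full checkpoint's state dict, assuming they are stored under a `backbone.` prefix. It only relies on plain mappings, so it should behave the same on a PyTorch state dict; the toy strings stand in for weight tensors, and the neck/head key names are made up for illustration.

```python
from collections import OrderedDict
from typing import Iterable, Mapping


def select_backbone_entries(full_state: Mapping, backbone_keys: Iterable[str],
                            prefix: str = "backbone.") -> OrderedDict:
    """Return the checkpoint entries that belong to the backbone.

    `backbone_keys` are keys as reported by model.backbone.state_dict(),
    i.e. without the prefix. `prefix` is an assumption about how the full
    model names its backbone submodule.
    """
    selected = OrderedDict()
    missing = []
    for key in backbone_keys:
        full_key = prefix + key
        if full_key in full_state:
            selected[full_key] = full_state[full_key]
        else:
            missing.append(full_key)
    if missing:
        # Fail loudly instead of silently returning a partial backbone
        raise KeyError(f"not found in checkpoint: {missing}")
    return selected


# Toy checkpoint: strings stand in for tensors; neck/head names are hypothetical.
checkpoint_state = OrderedDict([
    ("backbone.conv1.weight", "w0"),
    ("backbone.layer1.0.conv1.weight", "w1"),
    ("neck.lateral_convs.0.conv.weight", "w2"),
    ("roi_head.bbox_head.fc_cls.weight", "w3"),
])
backbone_keys = ["conv1.weight", "layer1.0.conv1.weight"]
backbone_state = select_backbone_entries(checkpoint_state, backbone_keys)
print(list(backbone_state))
# → ['backbone.conv1.weight', 'backbone.layer1.0.conv1.weight']
```

Raising on missing keys is a design choice: it makes a wrong prefix assumption visible immediately.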
Hello @aman_goyal,
If I understand correctly, you can extract the backbone weights from the provided model with the following code:
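```python
import torch

# Load the full end-to-end checkpoint onto the CPU
model = torch.load('qdtrack-frcnn_r50_fpn_12e_bdd100k-13328aed.pth', map_location='cpu')

# Keep only the parameters whose names belong to the backbone
new_state_dict = {}
for name, weight in model['state_dict'].items():
    if 'backbone' in name:
        new_state_dict[name] = weight
```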
This new state dict contains only the backbone's end-to-end trained weights. You could then create a new model, load this state dict into it, and attach a new neck or head.
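The extracted keys still carry the `backbone.` prefix, so a standalone ResNet-50 module would not recognise them as-is. A minimal sketch of stripping everything up to and including the first `backbone.` before loading is shown below; the helper name `strip_backbone_prefix` is hypothetical, and the `detector.backbone.` key is an assumed variant included only to show that a longer prefix is handled the same way.

```python
from collections import OrderedDict
from typing import Mapping


def strip_backbone_prefix(state: Mapping, marker: str = "backbone.") -> OrderedDict:
    """Keep only keys containing `marker` and drop everything up to and
    including its first occurrence, e.g. 'backbone.conv1.weight' ->
    'conv1.weight'. Keys without the marker are skipped."""
    stripped = OrderedDict()
    for name, value in state.items():
        if marker in name:
            stripped[name.split(marker, 1)[1]] = value
    return stripped


# Toy entries: strings stand in for tensors; key names are illustrative.
example = OrderedDict([
    ("backbone.conv1.weight", "w0"),
    ("detector.backbone.layer1.0.bn1.running_mean", "w1"),
    ("neck.fpn_convs.0.conv.weight", "w2"),
])
print(list(strip_backbone_prefix(example)))
# → ['conv1.weight', 'layer1.0.bn1.running_mean']
```

With a real checkpoint, the result could be passed to a ResNet-50 module's `load_state_dict(..., strict=False)`, which returns the missing and unexpected keys. Those are worth inspecting, since a different ResNet-50 implementation may name its parameters differently or include layers (such as a classification `fc`) that the detection backbone lacks.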
@sagorsarker Hey, I am slightly confused by the two code snippets. How should they be combined to produce a backbone.pth file? Could you please combine them into a single comment for clarity?
Hi @aman_goyal,
Here is the full code snippet I used:
```python
import torch
from collections import OrderedDict

# Load the full checkpoint (backbone + neck + head) onto the CPU
model = torch.load('qdtrack-frcnn_r50_fpn_12e_bdd100k-13328aed.pth', map_location='cpu')

# Keep only the backbone parameters, preserving their original order
new_state_dict = OrderedDict()
for name, param in model['state_dict'].items():
    if 'backbone' in name:
        new_state_dict[name] = param

# Build a new checkpoint dict with the same layout as the original
new_model = {
    'meta': model['meta'],
    'state_dict': new_state_dict,
}

# Save the new checkpoint, which contains only the backbone layers
torch.save(new_model, 'backbone.pth')
```
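To check that the saved file really contains only backbone entries, one option is to count state-dict keys per top-level module name. The sketch below uses a hypothetical helper `count_by_top_level_module` on toy key names (the neck/head names are made up); it applies the same `'backbone' in name` filter as the snippet above.

```python
from collections import Counter
from typing import Mapping


def count_by_top_level_module(state: Mapping) -> Counter:
    """Count state-dict entries per top-level module (text before the first '.')."""
    return Counter(name.split(".", 1)[0] for name in state)


# Toy full checkpoint: values are placeholders for tensors.
full = {
    "backbone.conv1.weight": 0,
    "backbone.layer1.0.conv1.weight": 0,
    "neck.lateral_convs.0.conv.weight": 0,
    "roi_head.bbox_head.fc_cls.weight": 0,
}
# Same filter as the extraction snippet
filtered = {k: v for k, v in full.items() if "backbone" in k}
print(count_by_top_level_module(full))
print(count_by_top_level_module(filtered))
```

On the real files, one could call `count_by_top_level_module(torch.load('backbone.pth', map_location='cpu')['state_dict'])` and expect only the backbone's top-level name to appear (which may be something like `detector` if the backbone is nested inside another module).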