Extract weights of a specific module from a pretrained model file

Greetings,

This is Aman Goyal. I am currently doing research at CMU and MSU.
I have a pre-trained model file that contains three modules: backbone, neck and head. It is a deep-learning-based tracking method.
What I need is just its backbone weights, which are a ResNet-50 trained on the BDD100K detection dataset.
Is it possible to extract only the backbone weights from the .pth model file, which contains all modules trained end to end?

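The state dict stored in the .pth file is an ordered dictionary (OrderedDict): its keys appear in the order in which the submodules and their parameters were registered when the model was instantiated.
This means that if you call model.backbone.state_dict().keys(), you get the names of the backbone's parameters and buffers. In the full checkpoint the same names appear with a "backbone." prefix.
Then you just pick those entries out of the .pth file.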
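To see how the key names line up, here is a minimal sketch that uses a hypothetical toy model (`ToyTracker`, with made-up layer sizes) in place of the real backbone/neck/head network. `model.backbone.state_dict()` returns keys without the `backbone.` prefix, while the full model's state dict (the thing a checkpoint stores) has them with the prefix, in registration order.

```python
import torch.nn as nn


class ToyTracker(nn.Module):
    """Hypothetical stand-in for a backbone + neck + head model."""

    def __init__(self):
        super().__init__()
        # Submodules are registered in this order, so their keys
        # appear in the same order in state_dict().
        self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
        self.neck = nn.Conv2d(8, 8, 1)
        self.head = nn.Linear(8, 2)


model = ToyTracker()

# Keys of the backbone on its own: no "backbone." prefix.
backbone_keys = list(model.backbone.state_dict().keys())
print(backbone_keys)

# Keys of the whole model: backbone entries carry the "backbone." prefix.
full_keys = list(model.state_dict().keys())
backbone_in_full = [k for k in full_keys if k.startswith("backbone.")]
print(backbone_in_full)

# Removing the prefix gives back exactly the backbone's own keys, in order.
stripped = [k[len("backbone."):] for k in backbone_in_full]
assert stripped == backbone_keys
```

The backbone entries also include BatchNorm buffers (`running_mean`, `running_var`, `num_batches_tracked`), so they are more than just the trainable parameters.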

Hello @aman_goyal,
If I understand correctly, you can extract the backbone weights of the provided model with the following code:

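import torch

# Load the full end-to-end checkpoint (map_location='cpu' avoids needing a GPU).
model = torch.load('qdtrack-frcnn_r50_fpn_12e_bdd100k-13328aed.pth', map_location='cpu')

# Keep only the entries whose names contain 'backbone'.
new_state_dict = {}
for name, weight in model['state_dict'].items():
    if 'backbone' in name:
        new_state_dict[name] = weight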

This new state dict contains the end-to-end trained backbone weights (the keys keep their "backbone." prefix). You could then load it into a new model and attach a new neck or head, I guess.
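Because the collected keys keep their `backbone.` prefix, they will not match a standalone backbone module as-is. A minimal sketch of stripping the prefix and loading the result, assuming a hypothetical toy backbone (`make_toy_backbone`) in place of the real ResNet-50 and a synthetic dict shaped like `model['state_dict']`; `strip_prefix` is also a hypothetical helper, not a library function.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def strip_prefix(state_dict, prefix="backbone."):
    """Keep only keys that start with `prefix`, with the prefix removed."""
    return OrderedDict(
        (name[len(prefix):], tensor)
        for name, tensor in state_dict.items()
        if name.startswith(prefix)
    )


def make_toy_backbone():
    # Hypothetical stand-in for the real ResNet-50 backbone.
    return nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))


# Synthetic end-to-end state dict, shaped like model['state_dict'].
full = OrderedDict()
for k, v in make_toy_backbone().state_dict().items():
    full["backbone." + k] = v
full["neck.weight"] = torch.randn(8, 8, 1, 1)
full["head.weight"] = torch.randn(2, 8)

# Load only the backbone weights into a freshly built backbone.
# strict=True raises if any key is missing or unexpected.
backbone = make_toy_backbone()
result = backbone.load_state_dict(strip_prefix(full), strict=True)
print(result)
```

For the real checkpoint, mmdetection-style ResNet backbones use the same layer names as torchvision's `resnet50` (`conv1`, `bn1`, `layer1`, ...), so loading the stripped weights into `torchvision.models.resnet50()` with `strict=False` (its `fc` layer has no counterpart) may work, but inspect the returned `missing_keys` and `unexpected_keys` rather than assuming it does.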

Thanks a lot for your prompt response. Can I create a new .pth file consisting of only the backbone weights?

Sure, you can.
Remove the other layers and keep only the backbone network,
then save it under whatever name you like.

new_model = {
    'meta': model['meta'],
    'state_dict': new_state_dict
}
torch.save(new_model, "backbone.pth")

Hope this will help.

Could you share the code snippet I would need to save the backbone weights as a .pth file?

Hello @aman_goyal,
please check my previous reply.
I have edited it to include a code snippet.

@sagorsarker Hey, I am slightly confused by the two code snippets now. How should they be combined to get backbone.pth? Could you please combine them in a single comment for clarity?

Thanks

Hi @aman_goyal,
Here is the full code snippet I used.

import torch
from collections import OrderedDict

# load the full checkpoint (a dict with 'meta' and 'state_dict' entries)
model = torch.load('qdtrack-frcnn_r50_fpn_12e_bdd100k-13328aed.pth', map_location='cpu')

# separate backbone layers, keeping their original order
new_state_dict = OrderedDict()
for name, param in model['state_dict'].items():
    if 'backbone' in name:
        new_state_dict[name] = param

# construct a new checkpoint with only the backbone layers
new_model = {
    'meta': model['meta'],
    'state_dict': new_state_dict,
}

# save the new checkpoint, which has only backbone layers
torch.save(new_model, 'backbone.pth')

Let me know if this works.


I did get ‘backbone.pth’ out of this, but is there any way to test it and verify whether it is correct?
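One way to sanity-check the file is to reload both checkpoints and confirm that the set of keys in `backbone.pth` is exactly the set of `backbone` keys in the original, and that every tensor is identical to its original. A minimal sketch, assuming both files use the `{'meta': ..., 'state_dict': ...}` layout from the snippet above; `check_backbone_file` is a hypothetical helper, and the demo uses tiny synthetic checkpoints instead of the real file.

```python
import os
import tempfile

import torch


def check_backbone_file(full_path, backbone_path):
    """Return a list of problems; an empty list means the check passed."""
    full = torch.load(full_path, map_location="cpu")["state_dict"]
    backbone = torch.load(backbone_path, map_location="cpu")["state_dict"]

    problems = []
    # Same selection rule as the extraction code: names containing 'backbone'.
    expected = {k for k in full if "backbone" in k}
    if set(backbone) != expected:
        missing = sorted(expected - set(backbone))
        extra = sorted(set(backbone) - expected)
        problems.append(f"key mismatch: missing={missing}, extra={extra}")
    # Every saved tensor must match the original bit for bit.
    for name, tensor in backbone.items():
        if name in full and not torch.equal(tensor, full[name]):
            problems.append(f"values differ for {name}")
    return problems


# Demo on tiny synthetic checkpoints (stand-ins for the real .pth files).
with tempfile.TemporaryDirectory() as tmp:
    full_path = os.path.join(tmp, "full.pth")
    good_path = os.path.join(tmp, "backbone.pth")
    bad_path = os.path.join(tmp, "backbone_bad.pth")

    state = {
        "backbone.conv1.weight": torch.randn(4, 3, 3, 3),
        "neck.weight": torch.randn(4, 4),
    }
    torch.save({"meta": {}, "state_dict": state}, full_path)

    # Correct extraction: only the backbone entries, unchanged.
    kept = {k: v for k, v in state.items() if "backbone" in k}
    torch.save({"meta": {}, "state_dict": kept}, good_path)

    # Broken extraction: same keys, altered values.
    tampered = {k: v + 1 for k, v in kept.items()}
    torch.save({"meta": {}, "state_dict": tampered}, bad_path)

    good_report = check_backbone_file(full_path, good_path)
    bad_report = check_backbone_file(full_path, bad_path)

print(good_report)
print(bad_report)
```

A stronger end-to-end check is to load the prefix-stripped weights into a standalone backbone with `strict=True` and compare its outputs on a dummy input against the full model's backbone. If `torch.load` refuses to unpickle the original checkpoint on a recent PyTorch release (where `weights_only=True` is the default), passing `weights_only=False` works, but only do that for files you trust.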