Loading a few layers from a pretrained MDNet

I guess you are using this repo. In that case, loading a .pth file such as mdnet_vot-otb.pth gives you a dictionary with a single key, 'shared_layers', which holds the weights of the layers trained offline.

Try it out:

import torch
model_weights = torch.load("mdnet_vot-otb.pth")
print(type(model_weights))
for k in model_weights: print(k)
for k in model_weights['shared_layers']: print("Shared layer", k)

It prints:

<class 'dict'>
shared_layers
Shared layer conv1.0.weight
Shared layer conv1.0.bias
Shared layer conv2.0.weight
Shared layer conv2.0.bias
Shared layer conv3.0.weight
Shared layer conv3.0.bias
Shared layer fc4.0.weight
Shared layer fc4.0.bias
Shared layer fc5.1.weight
Shared layer fc5.1.bias

Your MDNet object has a module called layers. I extracted its structure in the snippet below:

import torch
import torch.nn as nn
from collections import OrderedDict

layers = nn.Sequential(OrderedDict([
                ('conv1', nn.Sequential(nn.Conv2d(3, 96, kernel_size=7, stride=2),
                                        nn.ReLU(inplace=True),
                                        nn.LocalResponseNorm(2),
                                        nn.MaxPool2d(kernel_size=3, stride=2))),
                ('conv2', nn.Sequential(nn.Conv2d(96, 256, kernel_size=5, stride=2),
                                        nn.ReLU(inplace=True),
                                        nn.LocalResponseNorm(2),
                                        nn.MaxPool2d(kernel_size=3, stride=2))),
                ('conv3', nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=1),
                                        nn.ReLU(inplace=True))),
                ('fc4',   nn.Sequential(nn.Linear(512 * 3 * 3, 512),
                                        nn.ReLU(inplace=True))),
                ('fc5',   nn.Sequential(nn.Dropout(0.5),
                                        nn.Linear(512, 512),
                                        nn.ReLU(inplace=True)))]))
for k in layers.state_dict(): print("Module Layer", k)

And it prints:

Module Layer conv1.0.weight
Module Layer conv1.0.bias
Module Layer conv2.0.weight
Module Layer conv2.0.bias
Module Layer conv3.0.weight
Module Layer conv3.0.bias
Module Layer fc4.0.weight
Module Layer fc4.0.bias
Module Layer fc5.1.weight
Module Layer fc5.1.bias

This means the keys in shared_layers match the keys of the layers module exactly, and the tensor shapes agree as well. That’s why load_state_dict works without any extra handling.
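If you want to check this programmatically rather than by eye, here is a minimal sketch. The helper name diff_state_dict is my own (it is not part of the repo or of PyTorch), and the tiny toy network and fake_checkpoint are stand-ins for the real layers module and model_weights['shared_layers'], so the snippet runs without the .pth file:

```python
import torch
import torch.nn as nn
from collections import OrderedDict


def diff_state_dict(module, checkpoint):
    """Hypothetical helper: compare a module's state_dict with a checkpoint dict.

    Returns three lists of keys: missing from the checkpoint, unexpected in
    the checkpoint, and present in both but with different tensor shapes.
    """
    own = module.state_dict()
    missing = [k for k in own if k not in checkpoint]
    unexpected = [k for k in checkpoint if k not in own]
    mismatched = [k for k in own
                  if k in checkpoint and own[k].shape != checkpoint[k].shape]
    return missing, unexpected, mismatched


# Toy stand-in for the real network, small enough to run anywhere.
toy = nn.Sequential(OrderedDict([
    ('fc4', nn.Sequential(nn.Linear(8, 4), nn.ReLU())),
    ('fc5', nn.Sequential(nn.Dropout(0.5), nn.Linear(4, 4), nn.ReLU())),
]))

# Stand-in for model_weights['shared_layers'].
fake_checkpoint = {k: v.clone() for k, v in toy.state_dict().items()}
print(diff_state_dict(toy, fake_checkpoint))  # ([], [], [])

# Give one entry a different shape to see how a mismatch is reported.
fake_checkpoint['fc5.1.weight'] = torch.zeros(6, 4)
print(diff_state_dict(toy, fake_checkpoint))  # ([], [], ['fc5.1.weight'])
```

With the real files you would call it as diff_state_dict(layers, model_weights['shared_layers']). Three empty lists mean a plain load_state_dict should go through.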

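Suppose I changed the fc5 linear layer to nn.Linear(512, 1024). Now if I tried to load the weights, it wouldn’t work directly: the checkpoint’s fc5.1.weight and fc5.1.bias were saved for a 512-unit output, so load_state_dict fails with a size mismatch error. Here’s a workaround: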

import torch
import torch.nn as nn
from collections import OrderedDict

layers = nn.Sequential(OrderedDict([
                ('conv1', nn.Sequential(nn.Conv2d(3, 96, kernel_size=7, stride=2),
                                        nn.ReLU(inplace=True),
                                        nn.LocalResponseNorm(2),
                                        nn.MaxPool2d(kernel_size=3, stride=2))),
                ('conv2', nn.Sequential(nn.Conv2d(96, 256, kernel_size=5, stride=2),
                                        nn.ReLU(inplace=True),
                                        nn.LocalResponseNorm(2),
                                        nn.MaxPool2d(kernel_size=3, stride=2))),
                ('conv3', nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=1),
                                        nn.ReLU(inplace=True))),
                ('fc4',   nn.Sequential(nn.Linear(512 * 3 * 3, 512),
                                        nn.ReLU(inplace=True))),
                ('fc5',   nn.Sequential(nn.Dropout(0.5),
                                        nn.Linear(512, 1024),
                                        nn.ReLU(inplace=True)))]))

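# Load the checkpoint; only the 'shared_layers' entry is needed.
model_weights = torch.load("mdnet_vot-otb.pth")

d = model_weights['shared_layers']
# Overwrite the fc5 entries with freshly initialised tensors whose shapes
# match the new nn.Linear(512, 1024): weight is (out_features, in_features).
d['fc5.1.weight'] = torch.randn((1024, 512)) * 0.01
d['fc5.1.bias'] = torch.zeros(1024)
# Every key and shape now matches, so the strict load succeeds.
layers.load_state_dict(d)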
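An alternative you could try, if you would rather not overwrite entries by hand: keep only the checkpoint entries whose key and shape match the new module, and load them with strict=False. This is a sketch, not something taken from the repo. The names load_matching and make_toy are hypothetical, and the toy networks stand in for the real layers so the snippet runs without the .pth file:

```python
import torch
import torch.nn as nn
from collections import OrderedDict


def load_matching(module, checkpoint):
    """Hypothetical helper: load only entries whose key and shape match.

    Returns the sorted list of checkpoint keys that were skipped.
    """
    own = module.state_dict()
    keep = {k: v for k, v in checkpoint.items()
            if k in own and v.shape == own[k].shape}
    # strict=False lets load_state_dict accept a dict that lacks some keys.
    module.load_state_dict(keep, strict=False)
    return sorted(set(checkpoint) - set(keep))


def make_toy(fc5_out):
    """Tiny stand-in for the fc4/fc5 part of the network."""
    return nn.Sequential(OrderedDict([
        ('fc4', nn.Sequential(nn.Linear(8, 4), nn.ReLU())),
        ('fc5', nn.Sequential(nn.Dropout(0.5), nn.Linear(4, fc5_out), nn.ReLU())),
    ]))


pretrained = make_toy(4)   # plays the role of the saved network
modified = make_toy(6)     # same network with a wider fc5

skipped = load_matching(modified, pretrained.state_dict())
print(skipped)  # ['fc5.1.bias', 'fc5.1.weight']
print(torch.equal(modified.fc4[0].weight, pretrained.fc4[0].weight))  # True
```

With the real files this would be load_matching(layers, model_weights['shared_layers']). The difference from the workaround above is that the resized fc5 layer keeps PyTorch’s default initialisation instead of the randn * 0.01 weights and zero bias.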