Check if a module is in the `state_dict` or not

I’m not sure you will gain much from filtering out the other layers, since you usually initialize the model only once before training.
Would it be possible to just initialize all layers and load the pretrained parameters afterwards, or is your use case different?
If that would work, here is a small code example:

import torch
import torch.nn as nn
from collections import OrderedDict


def weight_init(m):
    # Initialize every linear layer with Xavier uniform weights and zero bias
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0.)


pretrained_model = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(10, 20)),
    ('relu1', nn.ReLU()),
    ('fc2', nn.Linear(20, 2))
]))
pretrained_dict = pretrained_model.state_dict()

model = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(10, 20)),
    ('relu1', nn.ReLU()),
    ('fc2', nn.Linear(20, 2)),
    ('relu2', nn.ReLU()),
    ('fc3', nn.Linear(2, 2))
]))
# Randomly initialize all layers of the new model
model.apply(weight_init)
model_dict = model.state_dict()

# Filter out keys that are not present in the new model's state_dict
filtered_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# Overwrite the matching entries and load the merged state_dict
model_dict.update(filtered_dict)
model.load_state_dict(model_dict)
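
To double-check what was actually transferred, you could run something like this afterwards (just a quick sketch on top of the code above, not part of the original snippet):

# Keys in the new model that have no pretrained counterpart
missing_keys = [k for k in model_dict if k not in pretrained_dict]
print(missing_keys)  # ['fc3.weight', 'fc3.bias']

# The shared layers should now hold the pretrained values
print(torch.equal(model.fc1.weight, pretrained_model.fc1.weight))  # True
print(torch.equal(model.fc2.weight, pretrained_model.fc2.weight))  # True

Alternatively, load_state_dict also accepts strict=False, which just ignores the missing keys, so model.load_state_dict(pretrained_dict, strict=False) would give the same result without the manual filtering.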