Check if a module is in the `state_dict` or not

Hey guys, I’m new to PyTorch. I want to ask how to check whether a module (when iterating over self.modules()) is in the state_dict of a pretrained model. I’m building an autoencoder with a pretrained encoder, so I want to initialize the network efficiently by only initializing the weights and biases of the layers that are not in the pretrained encoder. I hope that’s clear. Any help would be appreciated, thank you!

I’m not sure you would gain much from filtering out the other layers, since you usually initialize the model only once before training.
Would it be possible to initialize all layers first and load the pretrained parameters afterwards, or is your use case different?
If that would work, here is a small code example:

import torch.nn as nn
from collections import OrderedDict


def weight_init(m):
    # Initialize only linear layers; other module types (e.g. ReLU) are skipped
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0.)


# Pretrained model whose weights we want to reuse
pretrained_model = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(10, 20)),
    ('relu1', nn.ReLU()),
    ('fc2', nn.Linear(20, 2))
]))
pretrained_dict = pretrained_model.state_dict()

# New model: fc1 and fc2 match the pretrained model, relu2 and fc3 are new
model = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(10, 20)),
    ('relu1', nn.ReLU()),
    ('fc2', nn.Linear(20, 2)),
    ('relu2', nn.ReLU()),
    ('fc3', nn.Linear(2, 2))
]))
# Initialize all layers (the pretrained ones will be overwritten below)
model.apply(weight_init)
model_dict = model.state_dict()

# Filter out unnecessary keys (those not present in the new model)
filtered_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(filtered_dict)
model.load_state_dict(model_dict)
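
For completeness, here is a minimal sketch of the membership check you asked about, reusing model, pretrained_dict, and weight_init from the example above. It relies on the fact that state_dict keys are dotted parameter names such as fc1.weight, so a module’s name can be prefix-matched against them:

for name, module in model.named_modules():
    if not isinstance(module, nn.Linear):
        continue
    # state_dict keys look like 'fc1.weight', so prefix-match on the module name
    if any(key.startswith(name + '.') for key in pretrained_dict):
        continue  # this layer exists in the pretrained encoder, leave it alone
    weight_init(module)  # layer is new, so initialize it

Also note that load_state_dict accepts strict=False, which skips keys that don’t match, so you could load the pretrained weights directly with model.load_state_dict(pretrained_dict, strict=False) instead of building the filtered dict by hand.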

Maybe you’re right. I think I just over-complicated things. My code was a bit like yours, first initializing and then loading the pretrained model, and then I wondered whether I could do it more efficiently. But it seems I went too deep into the details. Thank you for your help!