[SOLVED] state_dict() and items()?

Hi, thank you as always for your help.
I am having trouble understanding the following code:

filtered_weight_dict = {k: v for k, v in weight_dict.items() if k in model.state_dict()}

Here, weight_dict and “model” come from two different models: weight_dict was loaded with torch.load from a pretrained model path, and “model” is a ResNet-18.
I cannot understand two points:
(1) why the two models should be compared before storing the weights,
and
(2) what items() means in general.

Below are the keys printed while iterating over weight_dict.items():
encoder.conv1.weight
encoder.bn1.weight
encoder.bn1.bias
encoder.bn1.running_mean
encoder.bn1.running_var
encoder.bn1.num_batches_tracked
encoder.layer1.0.conv1.weight
encoder.layer1.0.bn1.weight
encoder.layer1.0.bn1.bias
encoder.layer1.0.bn1.running_mean
encoder.layer1.0.bn1.running_var
encoder.layer1.0.bn1.num_batches_tracked
encoder.layer1.0.conv2.weight
encoder.layer1.0.bn2.weight
encoder.layer1.0.bn2.bias
encoder.layer1.0.bn2.running_mean
encoder.layer1.0.bn2.running_var
encoder.layer1.0.bn2.num_batches_tracked
encoder.layer1.1.conv1.weight
encoder.layer1.1.bn1.weight
encoder.layer1.1.bn1.bias
encoder.layer1.1.bn1.running_mean
encoder.layer1.1.bn1.running_var
encoder.layer1.1.bn1.num_batches_tracked
encoder.layer1.1.conv2.weight
encoder.layer1.1.bn2.weight
encoder.layer1.1.bn2.bias
encoder.layer1.1.bn2.running_mean
encoder.layer1.1.bn2.running_var
encoder.layer1.1.bn2.num_batches_tracked
encoder.layer2.0.conv1.weight
encoder.layer2.0.bn1.weight
encoder.layer2.0.bn1.bias
encoder.layer2.0.bn1.running_mean
encoder.layer2.0.bn1.running_var
encoder.layer2.0.bn1.num_batches_tracked
encoder.layer2.0.conv2.weight
encoder.layer2.0.bn2.weight
encoder.layer2.0.bn2.bias
encoder.layer2.0.bn2.running_mean
encoder.layer2.0.bn2.running_var
encoder.layer2.0.bn2.num_batches_tracked
encoder.layer2.0.downsample.0.weight
encoder.layer2.0.downsample.1.weight
encoder.layer2.0.downsample.1.bias
encoder.layer2.0.downsample.1.running_mean
encoder.layer2.0.downsample.1.running_var
encoder.layer2.0.downsample.1.num_batches_tracked
encoder.layer2.1.conv1.weight
encoder.layer2.1.bn1.weight
encoder.layer2.1.bn1.bias
encoder.layer2.1.bn1.running_mean
encoder.layer2.1.bn1.running_var
encoder.layer2.1.bn1.num_batches_tracked
encoder.layer2.1.conv2.weight
encoder.layer2.1.bn2.weight
encoder.layer2.1.bn2.bias
encoder.layer2.1.bn2.running_mean
encoder.layer2.1.bn2.running_var
encoder.layer2.1.bn2.num_batches_tracked
encoder.layer3.0.conv1.weight
encoder.layer3.0.bn1.weight
encoder.layer3.0.bn1.bias
encoder.layer3.0.bn1.running_mean
encoder.layer3.0.bn1.running_var
encoder.layer3.0.bn1.num_batches_tracked
encoder.layer3.0.conv2.weight
encoder.layer3.0.bn2.weight
encoder.layer3.0.bn2.bias
encoder.layer3.0.bn2.running_mean
encoder.layer3.0.bn2.running_var
encoder.layer3.0.bn2.num_batches_tracked
encoder.layer3.0.downsample.0.weight
encoder.layer3.0.downsample.1.weight
encoder.layer3.0.downsample.1.bias
encoder.layer3.0.downsample.1.running_mean
encoder.layer3.0.downsample.1.running_var
encoder.layer3.0.downsample.1.num_batches_tracked
encoder.layer3.1.conv1.weight
encoder.layer3.1.bn1.weight
encoder.layer3.1.bn1.bias
encoder.layer3.1.bn1.running_mean
encoder.layer3.1.bn1.running_var
encoder.layer3.1.bn1.num_batches_tracked
encoder.layer3.1.conv2.weight
encoder.layer3.1.bn2.weight
encoder.layer3.1.bn2.bias
encoder.layer3.1.bn2.running_mean
encoder.layer3.1.bn2.running_var
encoder.layer3.1.bn2.num_batches_tracked
encoder.layer4.0.conv1.weight
encoder.layer4.0.bn1.weight
encoder.layer4.0.bn1.bias
encoder.layer4.0.bn1.running_mean
encoder.layer4.0.bn1.running_var
encoder.layer4.0.bn1.num_batches_tracked
encoder.layer4.0.conv2.weight
encoder.layer4.0.bn2.weight
encoder.layer4.0.bn2.bias
encoder.layer4.0.bn2.running_mean
encoder.layer4.0.bn2.running_var
encoder.layer4.0.bn2.num_batches_tracked
encoder.layer4.0.downsample.0.weight
encoder.layer4.0.downsample.1.weight
encoder.layer4.0.downsample.1.bias
encoder.layer4.0.downsample.1.running_mean
encoder.layer4.0.downsample.1.running_var
encoder.layer4.0.downsample.1.num_batches_tracked
encoder.layer4.1.conv1.weight
encoder.layer4.1.bn1.weight
encoder.layer4.1.bn1.bias
encoder.layer4.1.bn1.running_mean
encoder.layer4.1.bn1.running_var
encoder.layer4.1.bn1.num_batches_tracked
encoder.layer4.1.conv2.weight
encoder.layer4.1.bn2.weight
encoder.layer4.1.bn2.bias
encoder.layer4.1.bn2.running_mean
encoder.layer4.1.bn2.running_var
encoder.layer4.1.bn2.num_batches_tracked
encoder.fc.weight
encoder.fc.bias

Thank you in advance.

Hi Petro,

I’m not sure if I understand your first question, but I’ll give it a go.

If you don’t want to train from scratch every time, but would rather reuse some saved state/weights, you have to load those weights from a file. In this code the saved model state is loaded into the weight_dict variable. When you iterate over it, k (the key) is the module/layer name and v (the value) is the corresponding weight tensor. items() is a method you can call on any dictionary to get its key/value pairs.
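For example (a minimal sketch, using torchvision’s resnet18 rather than the exact model from your post), iterating over a state dict’s items() gives (name, tensor) pairs:

import torchvision

model = torchvision.models.resnet18()

# state_dict() maps each parameter/buffer name to its tensor;
# items() yields those (name, tensor) pairs one by one.
for k, v in model.state_dict().items():
    print(k, tuple(v.shape))

# prints e.g.
# conv1.weight (64, 3, 7, 7)
# bn1.weight (64,)
# bn1.bias (64,)
# ...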

By checking whether each key from the loaded state dict also exists in the current model, we filter the weights down to the layers whose names match, and only then load them into the model. This check is typically done because the final fully connected layer often differs between a pretrained checkpoint and your own model (for example, a different name or a different number of output classes), and trying to load mismatched entries would make load_state_dict fail. See the sketch below.
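To make the flow concrete, here is a minimal sketch; the checkpoint path is a placeholder, and torchvision’s resnet18 stands in for the model in the original post (whose checkpoint keys carry an encoder. prefix, so in practice they may need renaming before any of them match). The shape comparison is an extra safeguard that is not in the original one-liner:

import torch
import torchvision

model = torchvision.models.resnet18()

# Load the pretrained weights (placeholder path).
weight_dict = torch.load('pretrained.pth', map_location='cpu')

# Keep only entries whose names (and, as an extra safeguard, shapes)
# also exist in the current model.
model_state = model.state_dict()
filtered_weight_dict = {k: v for k, v in weight_dict.items()
                        if k in model_state and v.shape == model_state[k].shape}

# Merge the filtered weights into the current state and load it.
model_state.update(filtered_weight_dict)
model.load_state_dict(model_state)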


Thank you for your kind reply, Oli.
I see that the code checks the existing pretrained model and gathers the layers that are appropriate to include in our model.

So it seems that if, for example, the loaded model excluded batch normalization, then our model will end up without those entries as well.
It is quite a nice and helpful function!
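Just as a hedged side note (not from the code above): load_state_dict(..., strict=False) achieves something similar by skipping non-matching key names and reporting them, though it still requires matching shapes for the keys it does load:

import torch
import torchvision

model = torchvision.models.resnet18()
checkpoint = torch.load('pretrained.pth', map_location='cpu')  # placeholder path

# strict=False ignores keys that exist on only one side and returns them.
result = model.load_state_dict(checkpoint, strict=False)
print('missing keys:', result.missing_keys)        # in the model but not in the checkpoint
print('unexpected keys:', result.unexpected_keys)  # in the checkpoint but not in the model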

Thank you very much:)
