Change the names of state_dict keys in a pretrained model

I have a model following the DenseNet architecture, but the problem is that the key names in the pretrained checkpoint don't match the keys of my model's state_dict. Some of the keys, for example:

My DenseNet121 version:

densenet121.features.conv0.weight
densenet121.features.norm0.weight
densenet121.features.norm0.bias
densenet121.features.norm0.running_mean
densenet121.features.norm0.running_var
densenet121.features.norm0.num_batches_tracked
densenet121.features.denseblock1.denselayer1.norm1.weight
densenet121.features.denseblock1.denselayer1.norm1.bias
densenet121.features.denseblock1.denselayer1.norm1.running_mean
densenet121.features.denseblock1.denselayer1.norm1.running_var
densenet121.features.denseblock1.denselayer1.norm1.num_batches_tracked
densenet121.features.denseblock1.denselayer1.conv1.weight
densenet121.features.denseblock1.denselayer1.norm2.weight
densenet121.features.denseblock1.denselayer1.norm2.bias
densenet121.features.denseblock1.denselayer1.norm2.running_mean
densenet121.features.denseblock1.denselayer1.norm2.running_var
densenet121.features.denseblock1.denselayer1.norm2.num_batches_tracked
densenet121.features.denseblock1.denselayer1.conv2.weight
densenet121.features.denseblock1.denselayer2.norm1.weight
densenet121.features.denseblock1.denselayer2.norm1.bias
densenet121.features.denseblock1.denselayer2.norm1.running_mean
densenet121.features.denseblock1.denselayer2.norm1.running_var
densenet121.features.denseblock1.denselayer2.norm1.num_batches_tracked
densenet121.features.denseblock1.denselayer2.conv1.weight

Pretrained Version:

module.densenet121.features.conv0.weight
module.densenet121.features.norm0.weight
module.densenet121.features.norm0.bias
module.densenet121.features.norm0.running_mean
module.densenet121.features.norm0.running_var
module.densenet121.features.denseblock1.denselayer1.norm.1.weight
module.densenet121.features.denseblock1.denselayer1.norm.1.bias
module.densenet121.features.denseblock1.denselayer1.norm.1.running_mean
module.densenet121.features.denseblock1.denselayer1.norm.1.running_var
module.densenet121.features.denseblock1.denselayer1.conv.1.weight
module.densenet121.features.denseblock1.denselayer1.norm.2.weight
module.densenet121.features.denseblock1.denselayer1.norm.2.bias
module.densenet121.features.denseblock1.denselayer1.norm.2.running_mean
module.densenet121.features.denseblock1.denselayer1.norm.2.running_var
module.densenet121.features.denseblock1.denselayer1.conv.2.weight
module.densenet121.features.denseblock1.denselayer2.norm.1.weight
module.densenet121.features.denseblock1.denselayer2.norm.1.bias
module.densenet121.features.denseblock1.denselayer2.norm.1.running_mean
module.densenet121.features.denseblock1.denselayer2.norm.1.running_var
module.densenet121.features.denseblock1.denselayer2.conv.1.weight
module.densenet121.features.denseblock1.denselayer2.norm.2.weight
module.densenet121.features.denseblock1.denselayer2.norm.2.bias
module.densenet121.features.denseblock1.denselayer2.norm.2.running_mean
module.densenet121.features.denseblock1.denselayer2.norm.2.running_var
module.densenet121.features.denseblock1.denselayer2.conv.2.weight
module.densenet121.features.denseblock1.denselayer3.norm.1.weight
module.densenet121.features.denseblock1.denselayer3.norm.1.bias
module.densenet121.features.denseblock1.denselayer3.norm.1.running_mean
module.densenet121.features.denseblock1.denselayer3.norm.1.running_var

I am trying to use the ChexNet model. The link contains the keys of ChexNet's state_dict in comparison to DenseNet 121, 169, and 201.
There are a few GitHub repos trying to replicate ChexNet, but when I tried to use their work to get the trained model, I faced the same issue in all of them:
errors for missing and unexpected keys while loading the state_dict.

if checkpoint is not None:
    # Load the saved checkpoint and restore the weights into the model
    modelCheckpoint = torch.load(checkpoint)
    model.load_state_dict(modelCheckpoint['state_dict'])

Any help is deeply appreciated!

You could manually remove the module. prefix from each key, e.g. by rebuilding the state_dict with renamed keys, as in the sketch below.
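A minimal sketch of this approach, assuming modelCheckpoint was loaded via torch.load as in the snippet above:

from collections import OrderedDict

# Rebuild the state_dict with the "module." prefix (added by
# nn.DataParallel) stripped from every key
new_state_dict = OrderedDict()
for key, value in modelCheckpoint['state_dict'].items():
    if key.startswith('module.'):
        key = key[len('module.'):]
    new_state_dict[key] = value

model.load_state_dict(new_state_dict)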

The better approach would be to store the state_dict of the plain model (not the nn.DataParallel model) via torch.save(model.module.state_dict(), PATH), which would avoid adding the module names.
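For example, a short sketch of that workflow (PATH is a placeholder for your checkpoint file):

# Training side: save the underlying module of the nn.DataParallel
# wrapper, so the keys are stored without the "module." prefix
torch.save(model.module.state_dict(), PATH)

# Loading side: a plain (non-DataParallel) model can load it directly
model.load_state_dict(torch.load(PATH))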


Also, num_batches_tracked is an extra entry added to the batchnorm layers in newer PyTorch versions of the DenseNet model, so this entry is missing in the pretrained version. Is there a way to skip it and copy the rest, or do you have any other suggestions?

You could try to use strict=False in model.load_state_dict to ignore these mismatches.
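For reference, a sketch combining the fixes discussed here for the checkpoint above: strip the module. prefix, rename the old norm.1/conv.1-style keys to the newer norm1/conv1 form (the same pattern torchvision applies when upgrading old DenseNet checkpoints), and pass strict=False so the missing num_batches_tracked entries are ignored. PATH is a placeholder for your checkpoint file:

import re
from collections import OrderedDict

import torch

checkpoint = torch.load(PATH)
state_dict = checkpoint['state_dict']

# Matches old-style keys such as "...denselayer1.norm.1.weight"
# so they can be rewritten as "...denselayer1.norm1.weight"
pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.'
    r'((?:[12])\.(?:weight|bias|running_mean|running_var))$')

new_state_dict = OrderedDict()
for key, value in state_dict.items():
    # Drop the "module." prefix added by nn.DataParallel
    if key.startswith('module.'):
        key = key[len('module.'):]
    # Collapse "norm.1" -> "norm1", "conv.2" -> "conv2", etc.
    match = pattern.match(key)
    if match:
        key = match.group(1) + match.group(2)
    new_state_dict[key] = value

# strict=False ignores the num_batches_tracked buffers that the old
# checkpoint does not contain
model.load_state_dict(new_state_dict, strict=False)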