[SOLVED] Unexpected key(s) in state_dict: "num_batches_tracked"

I am getting this error when loading the weights of a model I trained on a GPU; I am trying to load it on a CPU. I am not sure what the error means. I checked the checkpoint['state_dict'] keys and it does contain a 'num_batches_tracked' key for every batch normalization layer. I am using the same model definition on the CPU as I trained with on the GPU. The model is a UNet.

checkpoint = torch.load('model_best.pth-zeromask-notconsider.tar', map_location=lambda storage, loc: storage)

RuntimeError                              Traceback (most recent call last)
<ipython-input> in <module>()
----> 1 train_conv.model.load_state_dict(checkpoint['state_dict'])

~\AppData\Local\conda\conda\envs\myfastai1\lib\site-packages\torch\nn\modules\module.py in load_state_dict(self, state_dict, strict)
    719         if len(error_msgs) > 0:
    720             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 721                 self.__class__.__name__, "\n\t".join(error_msgs)))
    722
    723     def parameters(self):
RuntimeError: Error(s) in loading state_dict for Unet34:
Unexpected key(s) in state_dict: "rn.1.num_batches_tracked", "rn.4.0.bn1.num_batches_tracked", "rn.4.0.bn2.num_batches_tracked", "rn.4.1.bn1.num_batches_tracked", "rn.4.1.bn2.num_batches_tracked", "rn.4.2.bn1.num_batches_tracked", "rn.4.2.bn2.num_batches_tracked", "rn.5.0.bn1.num_batches_tracked", "rn.5.0.bn2.num_batches_tracked", "rn.5.0.downsample.1.num_batches_tracked", "rn.5.1.bn1.num_batches_tracked", "rn.5.1.bn2.num_batches_tracked", "rn.5.2.bn1.num_batches_tracked", "rn.5.2.bn2.num_batches_tracked", "rn.5.3.bn1.num_batches_tracked", "rn.5.3.bn2.num_batches_tracked", "rn.6.0.bn1.num_batches_tracked", "rn.6.0.bn2.num_batches_tracked", "rn.6.0.downsample.1.num_batches_tracked", "rn.6.1.bn1.num_batches_tracked", "rn.6.1.bn2.num_batches_tracked", "rn.6.2.bn1.num_batches_tracked", "rn.6.2.bn2.num_batches_tracked", "rn.6.3.bn1.num_batches_tracked", "rn.6.3.bn2.num_batches_tracked", "rn.6.4.bn1.num_batches_tracked", "rn.6.4.bn2.num_batches_tracked", "rn.6.5.bn1.num_batches_tracked", "rn.6.5.bn2.num_batches_tracked", "rn.7.0.bn1.num_batches_tracked", "rn.7.0.bn2.num_batches_tracked", "rn.7.0.downsample.1.num_batches_tracked", "rn.7.1.bn1.num_batches_tracked", "rn.7.1.bn2.num_batches_tracked", "rn.7.2.bn1.num_batches_tracked", "rn.7.2.bn2.num_batches_tracked", "up1.bn.num_batches_tracked", "up2.bn.num_batches_tracked", "up3.bn.num_batches_tracked", "up4.bn.num_batches_tracked".


Looks like it was a version problem - I trained the model in 0.4.1 and tried to load the dict in 0.4.0. Once I upgraded the pytorch version, everything worked fine.
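Before up- or downgrading, it can help to see exactly which keys differ between the checkpoint and the model, the same way load_state_dict reports them. A minimal sketch (the `diff_keys` helper name is mine, not part of PyTorch):

```python
def diff_keys(checkpoint_keys, model_keys):
    """Return (unexpected, missing) key lists, as load_state_dict reports them."""
    ckpt, model = set(checkpoint_keys), set(model_keys)
    unexpected = sorted(ckpt - model)  # in the checkpoint but not in the model
    missing = sorted(model - ckpt)     # in the model but not in the checkpoint
    return unexpected, missing

# In practice:
#   diff_keys(checkpoint['state_dict'].keys(), model.state_dict().keys())
unexpected, missing = diff_keys(
    ['conv1.weight', 'bn1.weight', 'bn1.num_batches_tracked'],
    ['conv1.weight', 'bn1.weight'])
print(unexpected)  # ['bn1.num_batches_tracked']
print(missing)     # []
```

If the only differences are `num_batches_tracked` entries, it is almost certainly this 0.4.0 vs 0.4.1 mismatch rather than a wrong model definition.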


Hi, I have the same question. I have some models saved in 0.4.1 and some in 0.4.0, and my PyTorch is 0.4.0. As you say, after upgrading PyTorch it loads the 0.4.1 models fine. But what about loading the 0.4.0 models?

Thanks !


I have the same problem: my model and my PyTorch are both 0.4.1, yet the problem is not solved and the same error still occurs.


I trained my model in PyTorch 0.4.1, but in the classroom we have 0.4.0 and I can't upgrade the classroom version. Is it possible to load a 0.4.1 model in 0.4.0, or do I have to downgrade my PyTorch version?
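If upgrading is not an option, one workaround is to drop the `num_batches_tracked` entries before loading, since a 0.4.0 model has no use for them anyway. A minimal sketch; the helper name is mine:

```python
def strip_num_batches_tracked(state_dict):
    """Drop the BatchNorm bookkeeping keys that 0.4.1 saves but 0.4.0 rejects."""
    return {k: v for k, v in state_dict.items()
            if not k.endswith('num_batches_tracked')}

# In practice:
#   checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
#   model.load_state_dict(strip_num_batches_tracked(checkpoint['state_dict']))
```

Alternatively, `load_state_dict(..., strict=False)` (the `strict` argument is visible in the traceback above) skips the mismatched keys, at the cost of also silencing genuinely missing parameters, so the explicit filter is safer.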

Are there any updates or solutions yet?
I am having a similar trouble loading a model.

Edit: I don't think this should be marked as solved, since the solution refers to a PyTorch version incompatibility. This sounds more like a workaround than a solution. To me, checkpoints should be compatible across multiple versions of PyTorch, since not everybody can up- or downgrade their version.


You can check the PyTorch densenet code for the solution. The internal state_dict keys have changed from something like features.denseblock1.denselayer1.norm.1.weight to features.denseblock1.denselayer1.norm1.weight.

Also, you can write a simple function that loads the model, modifies the state_dict, and saves the model. The rename pattern below is reconstructed from the old/new key names above, following the same shape as the one in torchvision's densenet loading code:

    import re
    import torch

    def modify_state_dict(state_dict):
        # Rename old keys like 'features.denseblock1.denselayer1.norm.1.weight'
        # to the new layout 'features.denseblock1.denselayer1.norm1.weight'.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.'
            r'((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        return state_dict

    checkpoint = torch.load(model_path)
    checkpoint['state_dict'] = modify_state_dict(checkpoint['state_dict'])
    torch.save(checkpoint, model_path)

As for num_batches_tracked, PyTorch added it in a later version. I have checked the value of these keys in the densenet layers and they are all tensor(0, device='cuda:0'). I think you can add the missing keys to the state_dict when modifying it. Btw, I can load my trained densenet model on v0.4.1 without that num_batches_tracked key.
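Going the other way (loading an old 0.4.0 checkpoint in 0.4.1 or later), the missing `num_batches_tracked` buffers can be filled with zeros as suggested above. A sketch under that assumption; the helper name is mine, and in real use you would pass the keys from `model.state_dict()` and `zero=lambda: torch.tensor(0)`:

```python
def add_num_batches_tracked(state_dict, model_keys, zero=lambda: 0):
    """Fill in the num_batches_tracked buffers an old checkpoint lacks."""
    out = dict(state_dict)
    for key in model_keys:
        if key.endswith('num_batches_tracked') and key not in out:
            out[key] = zero()  # freshly initialised buffers start at 0 anyway
    return out
```

Simpler still, `model.load_state_dict(old_state_dict, strict=False)` leaves those buffers at their initial value of 0, which matches what a fresh model would have.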

Agree with you!! Just downgrading the version to avoid this error is not elegant.