How to copy a modified state_dict into a model's state_dict

I am trying to copy a modified state_dict from a model that was pruned (e.g. dimension 0 of one of the tensors reduced by 1). The models have the same keys; the only difference is the dimensions of the tensors. I made a wrapper class to handle pruned networks, shown below:

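import torch

class PrunedNetwork:
    def __init__(self, network):
        self.network = network

    def __getattr__(self, name):
        # Forward any unknown attribute lookups to the wrapped network
        return getattr(self.network, name)

    def forward(self, *args, **kwargs):
        return self.network(*args, **kwargs)

    def save(self, filename):
        torch.save(self.network.state_dict(), filename)

    def load(self, filename):
        new_state_dict = torch.load(filename, map_location='cpu')

        for param_tensor in self.network.state_dict():
            print(param_tensor, "\t", self.network.state_dict()[param_tensor].size())

        # TODO: copy new_state_dict into state_dict
        # The following attempt doesn't work
        keys = self.network.state_dict().keys()
        for key in keys:
            self.network.state_dict()[key] = new_state_dict[key]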

The tensors aren’t being copied over to the model.

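I think you are only modifying a copy (the dictionary returned by state_dict()) here, not the parameters themselves:

    self.network.state_dict()[key] = new_state_dict[key]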

Why not just this?

    def load(self, filename):
        new_state_dict = torch.load(filename, map_location='cpu')
        self.network.load_state_dict(new_state_dict)

The above won't work if the dimensions don't match. So, in your case, I would recreate the whole network from scratch with the right dimensions and simply call load_state_dict, like below:

import torch

net = torch.nn.Linear(2, 2)
new_state_dict = {'weight': torch.rand(1, 2), 'bias': torch.rand(1)}

try:
    net.load_state_dict(new_state_dict)  # <- this won't work: size mismatch
except RuntimeError as err:
    print(err)

new_net = torch.nn.Linear(2, 1)
new_net.load_state_dict(new_state_dict)  # <- this works: shapes match

@LeviViana The problem is that I currently have about 20 weight files of pruned networks (with more on the way), so recreating every network by hand isn't really feasible. The goal is to build a network from a config file, adjust its tensor dimensions to match the pruned weight file, and then load the weights.

Another thing I don't understand is that self.network.state_dict()[param_tensor][0][0] = 0

sets those values of the tensor to 0, but assigning the full tensor doesn't retain the changes.

@ptrblck @albanD, since you've helped me before, I was wondering if you had any ideas. Thank you in advance.


As mentioned by @LeviViana, self.network.state_dict()[key] = new_state_dict[key] changes the dictionary created when calling self.network.state_dict() (not the Tensor), so it won't modify the network itself. self.network.state_dict()[param_tensor][0][0] = 0 does change the original Tensor, because the indexing returns a view of the original Tensor, which you then modify in place.
The right way to change the whole Tensor is to modify it in place, not the dict: self.network.state_dict()[key].copy_(new_state_dict[key])

Note that (most likely) you don't want such changes to be tracked by autograd, so you should wrap these in-place copies in a with torch.no_grad(): block.
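To make the difference concrete, here is a minimal sketch on a toy nn.Linear (a stand-in, not the Yolo network from the thread): rebinding a key in the dict returned by state_dict() leaves the module untouched, while copy_() on the same entry writes through to the parameter. Note that copy_() still needs shapes that match (or broadcast), so on its own it cannot load pruned tensors.

```python
import torch
import torch.nn as nn

net = nn.Linear(2, 2)
new_weight = torch.ones(2, 2)

# Rebinding a key only changes the OrderedDict that state_dict() just built;
# the module's parameter is untouched.
sd = net.state_dict()
sd["weight"] = new_weight
assert not torch.equal(net.weight, new_weight)

# copy_() writes into the tensor itself. The tensors returned by state_dict()
# are detached but share storage with the parameters, so the module sees it.
with torch.no_grad():
    net.state_dict()["weight"].copy_(new_weight)
assert torch.equal(net.weight, new_weight)
```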

For your original question, is the goal to load weights whose sizes differ from the current ones?
How do you handle that for layer parameters like in_channels/out_channels?



Doing the in-place copy with copy_() as suggested results in a

RuntimeError: The size of tensor a (32) must match the size of tensor b (13) at non-singleton dimension 0

since the given tensors have a reduced dimension.

The goal is to load pruned networks (in my case, pruned Yolov2 networks) that were saved during the pruning process. The idea was to create a wrapper class that instantiates the original network and adjusts it so it can load the weights of a pruned version of itself, without knowing beforehand what the size of each tensor will be. I hadn't fully considered the layer parameters yet, but I assumed these were accessible from the saved file as well?
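One way to see up front which tensors need resizing is to compare the shapes in the checkpoint against the freshly built model. The helper below, shape_mismatches, is a hypothetical name for illustration, and the two small nn.Sequential models are toy stand-ins for the original and pruned Yolov2 blocks.

```python
import torch
import torch.nn as nn

def shape_mismatches(model, state_dict):
    """Hypothetical helper: {key: (model_shape, checkpoint_shape)} for differing sizes."""
    current = model.state_dict()
    return {
        key: (tuple(current[key].shape), tuple(tensor.shape))
        for key, tensor in state_dict.items()
        if key in current and current[key].shape != tensor.shape
    }

# Toy stand-ins: the "pruned" conv lost 3 of its 16 filters.
full = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
pruned = nn.Sequential(nn.Conv2d(3, 13, 3), nn.BatchNorm2d(13))

for key, (have, want) in shape_mismatches(full, pruned.state_dict()).items():
    print(f"{key}: model {have} vs checkpoint {want}")
```

Both parameters (weight, bias) and buffers (running_mean, running_var) show up here, since the state dict contains both.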

The saved file only contains the state dict, i.e. only the dictionary of parameters (and buffers). You won't find the layer parameters there unless you save them explicitly alongside it.
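A sketch of what saving the layer sizes alongside the state dict could look like, assuming you control the saving code. The checkpoint layout (the "state_dict" and "config" keys and the conv_out_channels entry) is a hypothetical example, and io.BytesIO stands in for a file path.

```python
import io
import torch
import torch.nn as nn

# Toy pruned model: 13 output channels instead of the original 16.
model = nn.Sequential(nn.Conv2d(3, 13, 3), nn.BatchNorm2d(13))

# Hypothetical checkpoint layout: the state dict plus the sizes needed to
# rebuild the layers before loading.
checkpoint = {
    "state_dict": model.state_dict(),
    "config": {"conv_out_channels": model[0].out_channels},
}
f = io.BytesIO()  # stands in for a filename
torch.save(checkpoint, f)

# Later: read the sizes first, build a matching model, then load the weights.
f.seek(0)
loaded = torch.load(f, map_location="cpu")
c = loaded["config"]["conv_out_channels"]
rebuilt = nn.Sequential(nn.Conv2d(3, c, 3), nn.BatchNorm2d(c))
rebuilt.load_state_dict(loaded["state_dict"])
```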

If you can load the layer information first and create the PrunedNetwork with the right sizes, it might be easier to load the weights afterwards.

@albanD Unfortunately this isn't an option for me anymore, since I have a deadline pretty soon and repruning the networks takes several days. I might be able to change each layer's parameters based on the sizes of the tensors within, but I still need a way to edit the original network's dictionary in this case.
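If the layer sizes are derived from the tensors, the constructor attributes (out_channels, in_channels, num_features) can be updated to match. A minimal sketch, assuming only Conv2d and BatchNorm2d layers need handling; sync_layer_attributes is a hypothetical helper. As far as I can tell, these modules' forward() reads only the tensors and settings such as stride, padding and groups, so this mainly keeps repr() and any config-reading code consistent.

```python
import torch
import torch.nn as nn

def sync_layer_attributes(model):
    """Hypothetical helper: copy tensor sizes back into constructor attributes."""
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            # Conv2d weight layout: (out_channels, in_channels // groups, kH, kW)
            module.out_channels = module.weight.shape[0]
            module.in_channels = module.weight.shape[1] * module.groups
        elif isinstance(module, nn.BatchNorm2d):
            # affine=False layers have no weight; fall back to the running stats
            ref = module.weight if module.weight is not None else module.running_mean
            if ref is not None:
                module.num_features = ref.shape[0]
    return model

# Toy example: shrink a 16-channel conv/batchnorm pair to 13 channels by
# swapping in tensors of the new size, then fix up the attributes.
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
conv, bn = model[0], model[1]
conv.weight = nn.Parameter(torch.randn(13, 3, 3, 3))
conv.bias = nn.Parameter(torch.randn(13))
bn.weight = nn.Parameter(torch.ones(13))
bn.bias = nn.Parameter(torch.zeros(13))
bn.register_buffer("running_mean", torch.zeros(13))
bn.register_buffer("running_var", torch.ones(13))
sync_layer_attributes(model)
```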

Editing the dictionary won't help, as the dictionary is created when you call .state_dict().
If you want to change the tensors, the simplest solution I can think of right now is to go through the network and, for each parameter, do:

import torch.nn as nn

# mod is the current module that has a parameter weight that needs
# to be changed to new_weight (a tensor of the new shape)
del mod.weight
mod.weight = nn.Parameter(new_weight)
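Applied to a whole network, this can be sketched as a loop over the checkpoint keys: resolve each key to its submodule, swap in a tensor of the saved shape, then call load_state_dict() as usual. resize_to_state_dict is a hypothetical helper; it relies on Module.get_submodule (available in newer PyTorch releases) and treats buffers separately so they stay plain tensors rather than nn.Parameter.

```python
import torch
import torch.nn as nn

def resize_to_state_dict(model, state_dict):
    """Hypothetical helper: reshape every parameter/buffer named in state_dict
    so that model.load_state_dict(state_dict) no longer hits size mismatches."""
    for key, saved in state_dict.items():
        # "1.running_mean" -> module path "1", tensor name "running_mean"
        module_path, _, name = key.rpartition(".")
        module = model.get_submodule(module_path)  # "" resolves to model itself
        params = dict(module.named_parameters(recurse=False))
        buffers = dict(module.named_buffers(recurse=False))
        if name in params and params[name].shape != saved.shape:
            # Learnable tensor: replace with a fresh nn.Parameter of the saved shape
            setattr(module, name, nn.Parameter(torch.empty_like(saved)))
        elif name in buffers and buffers[name].shape != saved.shape:
            # Running statistic: keep it a plain buffer (requires_grad stays False)
            module.register_buffer(name, torch.empty_like(saved))
    return model

# Toy stand-ins: an "original" block and a pruned copy with 13 instead of 16 channels.
full = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
pruned = nn.Sequential(nn.Conv2d(3, 13, 3), nn.BatchNorm2d(13))
checkpoint = pruned.state_dict()

model = resize_to_state_dict(full, checkpoint)
model.load_state_dict(checkpoint)  # shapes now match, so the values are copied in
```

The new tensors are created on the checkpoint's device, so it is simplest to run this before moving the model to the GPU.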

@albanD I managed to recreate the network and copy over all tensors by iterating over the keys and assigning each one, with some string-parsing magic, to the matching submodule via [firstindex][secondindex] indexing.

I do however still receive an error:

RuntimeError: the derivative for 'running_mean' is not implemented

which I looked up, and I understand that these tensors don't have to be wrapped in nn.Parameter().
However when I leave out the nn.Parameter() I get the following error:

RuntimeError: Error(s) in loading state_dict for Yolo:
        Unexpected key(s) in state_dict: "layers.0.1_convbatch.layers.1.running_mean", "layers.0.1_convbatch.layers.1.running_var", ...... 

How do I copy the running var and running mean tensors?

These are not nn.Parameter but buffers. They are plain Tensors and are registered with register_buffer().
These buffers should not require gradients. The error you're seeing occurs because the formula to compute the gradient with respect to the running statistics is not implemented for batchnorm.

@albanD I don’t think I fully understand. How do I solve this issue?

To solve “RuntimeError: the derivative for ‘running_mean’ is not implemented”, you need to make sure that batchnorm_mod.running_mean.requires_grad == False.
What might be happening is that your loading script treats Parameters and buffers the same way, and so makes all of them require gradients. You want to be careful not to set requires_grad=True for the buffers.
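A quick way to check which tensors are which, sketched on a standalone BatchNorm2d(13) (the sizes are arbitrary): named_parameters() lists the learnable affine terms, named_buffers() lists the running statistics, and re-registering a resized statistic with register_buffer() keeps it a plain tensor that does not require gradients.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(13)

# Learnable affine terms are parameters: the optimizer updates them via gradients.
param_names = [name for name, _ in bn.named_parameters()]
# Running statistics are buffers: saved in the state_dict, updated in forward()
# during training, and never meant to require gradients.
buffer_names = [name for name, _ in bn.named_buffers()]
print(param_names)   # ['weight', 'bias']
print(buffer_names)  # ['running_mean', 'running_var', 'num_batches_tracked']

# Swapping in resized statistics: re-register them as buffers (plain tensors)
# instead of wrapping them in nn.Parameter. Sizes here are only illustrative.
bn.register_buffer("running_mean", torch.zeros(10))
bn.register_buffer("running_var", torch.ones(10))

# Sanity check worth running after any custom loading code.
assert not any(b.requires_grad for b in bn.buffers())
```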


locallayer.register_buffer('running_mean', new_state_dict[param_tensor])

This did the trick! Thank you very much again for your help and patience.