Modifying the state_dict changes the values of the parameters

Hello,

I am currently trying to map the parameters of a pre-trained network to another network of the exact same architecture, but with a different arrangement of sub-modules. Here is how I am doing it:

def _load_parameters(self):

    self.old_state_dict = torch.load(f = self.param_path,
                                     map_location = torch.device('cpu'))

    # make a copy to use load_state_dict() method later

    state_dict = copy.deepcopy(self.state_dict())

    with open('state_dict_map.csv',newline='') as csvfile:
        reader = csv.reader(csvfile,delimiter=',')
        for old_key,new_key in reader:
            state_dict[new_key] = self.old_state_dict[old_key]

    self.load_state_dict(state_dict)

This method is defined in the Net class, which also contains the forward() method. The state_dict_map.csv file maps old parameter names to new parameter names, with one old_key,new_key pair per line. For example:

module.layer1.0.weight,stage1.0.weight
module.layer1.0.bias,stage1.0.bias
module.layer1.1.weight,stage1.1.weight
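The mapping file can also be read into a dictionary up front, which makes it easy to catch a target name that appears twice in the CSV. This is a minimal sketch, assuming the CSV has no header row; io.StringIO holding the three example lines stands in for the real file, and read_key_map is a hypothetical helper name:

```python
import csv
import io

import torch

# Hypothetical stand-in for state_dict_map.csv: no header row,
# one "old_key,new_key" pair per line.
CSV_TEXT = """module.layer1.0.weight,stage1.0.weight
module.layer1.0.bias,stage1.0.bias
module.layer1.1.weight,stage1.1.weight
"""


def read_key_map(csvfile):
    """Return a dict mapping new parameter names to old parameter names."""
    key_map = {}
    for old_key, new_key in csv.reader(csvfile, delimiter=','):
        if new_key in key_map:
            # Two old tensors targeting one new key would silently overwrite each other.
            raise ValueError(f"duplicate target key: {new_key}")
        key_map[new_key] = old_key
    return key_map


key_map = read_key_map(io.StringIO(CSV_TEXT))

# Random placeholders with the shapes used in this thread.
old_state_dict = {
    'module.layer1.0.weight': torch.randn(16, 1, 3, 3),
    'module.layer1.0.bias': torch.randn(16),
    'module.layer1.1.weight': torch.randn(16),
}
new_entries = {new: old_state_dict[old] for new, old in key_map.items()}
print(sorted(new_entries))  # ['stage1.0.bias', 'stage1.0.weight', 'stage1.1.weight']
```

Note that this only guards against duplicate names in the file itself; it cannot tell whether two distinct new keys end up pointing at the same module inside the network.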

While debugging, just before I run self.load_state_dict(state_dict), I check the following parameters:

>>> self.old_state_dict['module.layer1.1.weight']
tensor([1.1605, 1.0575, 0.8263, 1.1720, 0.8299, 0.9742, 0.7409,
1.0015, 0.9154, 0.7924, 0.7753, 1.1176, 0.8313, 0.9920, 1.0639, 1.0233])
>>> state_dict['stage1.1.weight']
tensor([1.1605, 1.0575, 0.8263, 1.1720, 0.8299, 0.9742, 0.7409,
1.0015, 0.9154, 0.7924, 0.7753, 1.1176, 0.8313, 0.9920, 1.0639, 1.0233])

Clearly, they are the same, which is expected. Also, the original state_dict has not yet been modified:

>>> self.state_dict()['stage1.1.weight']
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

However, after I run self.load_state_dict(state_dict), I get this:

>>> self.state_dict()['stage1.1.weight']
tensor([1.0268, 1.0117, 1.0257, 1.0075, 0.9960, 0.9114, 0.9855,
0.9940, 0.9469, 0.9148, 0.9430, 1.0411, 0.9435, 1.1579, 0.9916, 1.0996])

Which is not the same as:

>>> state_dict['stage1.1.weight']
tensor([1.1605, 1.0575, 0.8263, 1.1720, 0.8299, 0.9742, 0.7409,
1.0015, 0.9154, 0.7924, 0.7753, 1.1176, 0.8313, 0.9920, 1.0639, 1.0233])

FYI, module.layer1.1.weight and stage1.1.weight are the weights of batch normalization layers, but I get the same problem for all layers in the net, including convolutional layers. I am not sure where exactly these numbers come from. Any suggestions?

Thanks.

That does look a bit weird.
Unfortunately, I cannot reproduce it with this small code snippet:

import copy

import torch
from torch import nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.bn = nn.BatchNorm2d(3)

    def _load_param(self, old_state_dict):
        state_dict = copy.deepcopy(self.state_dict())
        for key in state_dict:
            # map new names to old names, e.g. 'bn.weight' -> 'weight'
            old_key = key.replace('bn.', '')
            state_dict[key] = old_state_dict[old_key]

        self.load_state_dict(state_dict)

# run some forward passes in training mode so the running stats
# of the standalone layer differ from their default values
bn_old = nn.BatchNorm2d(3)
for _ in range(100):
    out = bn_old(torch.randn(8, 3, 24, 24) * 5 + 7)

model = MyModel()
print(model.state_dict())

model._load_param(bn_old.state_dict())
print(model.state_dict())
print(bn_old.state_dict())

Could you check what might be missing from my code and post a code snippet that reproduces this issue?

Hi @ptrblck,

Thanks a lot for the quick response. Here is my full code:

import torch
from torch import nn
import csv
import copy

class Net(nn.Module):
    
    def __init__(self,param_path,state_dict_map_path):
        
        # run nn.Module's constructor
        
        super(Net,self).__init__()
        
        # path to .pkl file
        
        self.param_path = param_path
        
        # path to .csv file
        
        self.state_dict_map_path = state_dict_map_path
        
        # build net
        
        in_channels = 16
        
        conv1 = nn.Conv2d(in_channels = 1,
                         out_channels = in_channels,
                         kernel_size = (3,3),
                         padding = (1,1))
        
        batch_norm = nn.BatchNorm2d(num_features = 16)
        
        activation = nn.ReLU()
        
        conv2 = nn.Conv2d(in_channels = in_channels,
                         out_channels = in_channels,
                         kernel_size = (3,3),
                         padding = (1,1))
        
        pooling = nn.MaxPool2d(kernel_size = (2,2))
        
        # first stage
        
        stages = [nn.Sequential(conv1,
                                batch_norm,
                                activation,
                                conv2,
                                batch_norm,
                                activation,
                                pooling)]
        
        # next 4 stages
        
        for i in range(4):
            
            conv1 = nn.Conv2d(in_channels = in_channels,
                              out_channels = in_channels * 2,
                              kernel_size = (3,3),
                              padding = (1,1))
            
            batch_norm = nn.BatchNorm2d(num_features = in_channels * 2)
            
            conv2 = nn.Conv2d(in_channels = in_channels * 2,
                              out_channels = in_channels * 2,
                              kernel_size = (3,3),
                              padding = (1,1))
            
            stages += [nn.Sequential(conv1,
                                     batch_norm,
                                     activation,
                                     conv2,
                                     batch_norm,
                                     activation,
                                     pooling)]
            
            in_channels = in_channels * 2
            
        # 6th stage, in_channels = 256
        
        conv1 = nn.Conv2d(in_channels = in_channels,
                          out_channels = in_channels * 2,
                          kernel_size = (3,3),
                          padding = (1,1))
            
        batch_norm = nn.BatchNorm2d(num_features = in_channels * 2)
            
        stages += [nn.Sequential(conv1,
                                 batch_norm,
                                 activation,
                                 pooling)]
        
        in_channels = in_channels * 2
        
        # final stage, in_channels = 512
        
        conv1 = nn.Conv2d(in_channels = in_channels,
                          out_channels = in_channels * 2,
                          kernel_size = (2,2),
                          padding = (0,0))
            
        batch_norm = nn.BatchNorm2d(num_features = in_channels * 2)
            
        stages += [nn.Sequential(conv1,
                                 batch_norm,
                                 activation)]
        
        # assign names to the stages for the state_dict
        
        self.stage1 = stages[0]
        self.stage2 = stages[1]
        self.stage3 = stages[2]
        self.stage4 = stages[3]
        self.stage5 = stages[4]
        self.stage6 = stages[5]
        self.stage7 = stages[6]
        
        # load the pre-trained parameters
        
        self._load_parameters()
    
    # assign parameters from file to layers
    
    def _load_parameters(self):
        
        self.old_state_dict = torch.load(f = self.param_path,
                                         map_location = torch.device('cpu'))
        
        # make a copy to use load_state_dict() method later
        
        state_dict = copy.deepcopy(self.state_dict())
        
        with open(self.state_dict_map_path,newline='') as csvfile:
            reader = csv.reader(csvfile,delimiter=',')
            for old_key,new_key in reader:
                state_dict[new_key] = self.old_state_dict[old_key]
        
        self.load_state_dict(state_dict)
    
    def forward(self,x):
        
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        
        # average of all pixels in each feature map
        
        x = nn.functional.avg_pool2d(input = x,
                                     kernel_size = x.shape[2:])
        
        # flatten from N x 1024 x 1 x 1 to N x 1024
        
        x = torch.flatten(input = x,
                           start_dim = 1)
        
        return x

net = Net(param_path,state_dict_map_path)

My code is very similar to yours, except that I call self._load_parameters() at the end of the constructor instead of calling the loading method from outside the class. Does this make a difference? If not, do you have any other suggestions?

By the way @ptrblck, you can ignore most of the code in the constructor, except for the relevant line, self._load_parameters().

Could you add a code snippet which would create the state_dict in the desired format, so that I could use your current code to load and debug it, please? I cannot see any obvious issues and would need to dig a bit into the code.

Hi @ptrblck,

I wasn’t sure what you meant, but I made a self-contained example:

import torch
from torch import nn
import csv
import copy

class Net(nn.Module):
    
    def __init__(self,old_state_dict,state_dict_map):
        
        # run nn.Module's constructor
        
        super(Net,self).__init__()
        
        #--------------------------------------------------------------------
        # build net
        
        in_channels = 16
        
        conv1 = nn.Conv2d(in_channels = 1,
                         out_channels = in_channels,
                         kernel_size = (3,3),
                         padding = (1,1))
        
        batch_norm = nn.BatchNorm2d(num_features = 16)
        
        activation = nn.ReLU()
        
        conv2 = nn.Conv2d(in_channels = in_channels,
                         out_channels = in_channels,
                         kernel_size = (3,3),
                         padding = (1,1))
        
        pooling = nn.MaxPool2d(kernel_size = (2,2))
        
        # first stage
        
        stages = [nn.Sequential(conv1,
                                batch_norm,
                                activation,
                                conv2,
                                batch_norm,
                                activation,
                                pooling)]
        
        # next 4 stages
        
        for i in range(4):
            
            conv1 = nn.Conv2d(in_channels = in_channels,
                              out_channels = in_channels * 2,
                              kernel_size = (3,3),
                              padding = (1,1))
            
            batch_norm = nn.BatchNorm2d(num_features = in_channels * 2)
            
            conv2 = nn.Conv2d(in_channels = in_channels * 2,
                              out_channels = in_channels * 2,
                              kernel_size = (3,3),
                              padding = (1,1))
            
            stages += [nn.Sequential(conv1,
                                     batch_norm,
                                     activation,
                                     conv2,
                                     batch_norm,
                                     activation,
                                     pooling)]
            
            in_channels = in_channels * 2
            
        # 6th stage, in_channels = 256
        
        conv1 = nn.Conv2d(in_channels = in_channels,
                          out_channels = in_channels * 2,
                          kernel_size = (3,3),
                          padding = (1,1))
            
        batch_norm = nn.BatchNorm2d(num_features = in_channels * 2)
            
        stages += [nn.Sequential(conv1,
                                 batch_norm,
                                 activation,
                                 pooling)]
        
        in_channels = in_channels * 2
        
        # final stage, in_channels = 512
        
        conv1 = nn.Conv2d(in_channels = in_channels,
                          out_channels = in_channels * 2,
                          kernel_size = (2,2),
                          padding = (0,0))
            
        batch_norm = nn.BatchNorm2d(num_features = in_channels * 2)
            
        stages += [nn.Sequential(conv1,
                                 batch_norm,
                                 activation)]
        
        # assign names to the stages for the state_dict
        
        self.stage1 = stages[0]
        self.stage2 = stages[1]
        self.stage3 = stages[2]
        self.stage4 = stages[3]
        self.stage5 = stages[4]
        self.stage6 = stages[5]
        self.stage7 = stages[6]
        
        #--------------------------------------------------------------------
        
        # load the pre-trained parameters
        
        # make a copy to use load_state_dict() method later
        
        state_dict = copy.deepcopy(self.state_dict())
        
        for new_key,old_key in state_dict_map.items():
            state_dict[new_key] = old_state_dict[old_key]
            
        self.load_state_dict(state_dict)
    
    def forward(self,x):
        
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        
        # average of all pixels in each feature map
        
        x = nn.functional.avg_pool2d(input = x,
                                     kernel_size = x.shape[2:])
        
        # flatten from N x 1024 x 1 x 1 to N x 1024
        
        x = torch.flatten(input = x,
                           start_dim = 1)
        
        return x

torch.manual_seed(42)

old_state_dict = {'module.layer1.0.weight':torch.randn((16, 1, 3, 3)),
                  'module.layer1.0.bias':torch.randn((16)),
                  'module.layer1.1.weight':torch.randn((16)),
                  'module.layer1.1.bias':torch.randn((16)),
                  'module.layer1.1.running_mean':torch.randn((16)),
                  'module.layer1.1.running_var':torch.randn((16)),
                  'module.layer2.0.weight':torch.randn((16, 16, 3, 3)),
                  'module.layer2.0.bias':torch.randn((16)),
                  'module.layer2.1.weight':torch.randn((16)),
                  'module.layer2.1.bias':torch.randn((16)),
                  'module.layer2.1.running_mean':torch.randn((16)),
                  'module.layer2.1.running_var':torch.randn((16)),
                  'module.layer4.0.weight':torch.randn((32, 16, 3, 3)),
                  'module.layer4.0.bias':torch.randn((32)),
                  'module.layer4.1.weight':torch.randn((32)),
                  'module.layer4.1.bias':torch.randn((32)),
                  'module.layer4.1.running_mean':torch.randn((32)),
                  'module.layer4.1.running_var':torch.randn((32)),
                  'module.layer5.0.weight':torch.randn((32, 32, 3, 3)),
                  'module.layer5.0.bias':torch.randn((32)),
                  'module.layer5.1.weight':torch.randn((32)),
                  'module.layer5.1.bias':torch.randn((32)),
                  'module.layer5.1.running_mean':torch.randn((32)),
                  'module.layer5.1.running_var':torch.randn((32)),
                  'module.layer7.0.weight':torch.randn((64, 32, 3, 3)),
                  'module.layer7.0.bias':torch.randn((64)),
                  'module.layer7.1.weight':torch.randn((64)),
                  'module.layer7.1.bias':torch.randn((64)),
                  'module.layer7.1.running_mean':torch.randn((64)),
                  'module.layer7.1.running_var':torch.randn((64)),
                  'module.layer8.0.weight':torch.randn((64, 64, 3, 3)),
                  'module.layer8.0.bias':torch.randn((64)),
                  'module.layer8.1.weight':torch.randn((64)),
                  'module.layer8.1.bias':torch.randn((64)),
                  'module.layer8.1.running_mean':torch.randn((64)),
                  'module.layer8.1.running_var':torch.randn((64)),
                  'module.layer10.0.weight':torch.randn((128, 64, 3, 3)),
                  'module.layer10.0.bias':torch.randn((128)),
                  'module.layer10.1.weight':torch.randn((128)),
                  'module.layer10.1.bias':torch.randn((128)),
                  'module.layer10.1.running_mean':torch.randn((128)),
                  'module.layer10.1.running_var':torch.randn((128)),
                  'module.layer11.0.weight':torch.randn((128, 128, 3, 3)),
                  'module.layer11.0.bias':torch.randn((128)),
                  'module.layer11.1.weight':torch.randn((128)),
                  'module.layer11.1.bias':torch.randn((128)),
                  'module.layer11.1.running_mean':torch.randn((128)),
                  'module.layer11.1.running_var':torch.randn((128)),
                  'module.layer13.0.weight':torch.randn((256, 128, 3, 3)),
                  'module.layer13.0.bias':torch.randn((256)),
                  'module.layer13.1.weight':torch.randn((256)),
                  'module.layer13.1.bias':torch.randn((256)),
                  'module.layer13.1.running_mean':torch.randn((256)),
                  'module.layer13.1.running_var':torch.randn((256)),
                  'module.layer14.0.weight':torch.randn((256, 256, 3, 3)),
                  'module.layer14.0.bias':torch.randn((256)),
                  'module.layer14.1.weight':torch.randn((256)),
                  'module.layer14.1.bias':torch.randn((256)),
                  'module.layer14.1.running_mean':torch.randn((256)),
                  'module.layer14.1.running_var':torch.randn((256)),
                  'module.layer16.0.weight':torch.randn((512, 256, 3, 3)),
                  'module.layer16.0.bias':torch.randn((512)),
                  'module.layer16.1.weight':torch.randn((512)),
                  'module.layer16.1.bias':torch.randn((512)),
                  'module.layer16.1.running_mean':torch.randn((512)),
                  'module.layer16.1.running_var':torch.randn((512)),
                  'module.layer18.0.weight':torch.randn((1024, 512, 2, 2)),
                  'module.layer18.0.bias':torch.randn((1024)),
                  'module.layer18.1.weight':torch.randn((1024)),
                  'module.layer18.1.bias':torch.randn((1024)),
                  'module.layer18.1.running_mean':torch.randn((1024)),
                  'module.layer18.1.running_var':torch.randn((1024))}

state_dict_map = {'stage1.0.weight':'module.layer1.0.weight',
                  'stage1.0.bias':'module.layer1.0.bias',
                  'stage1.1.weight':'module.layer1.1.weight',
                  'stage1.1.bias':'module.layer1.1.bias',
                  'stage1.1.running_mean':'module.layer1.1.running_mean',
                  'stage1.1.running_var':'module.layer1.1.running_var',
                  'stage1.3.weight':'module.layer2.0.weight',
                  'stage1.3.bias':'module.layer2.0.bias',
                  'stage1.4.weight':'module.layer2.1.weight',
                  'stage1.4.bias':'module.layer2.1.bias',
                  'stage1.4.running_mean':'module.layer2.1.running_mean',
                  'stage1.4.running_var':'module.layer2.1.running_var',
                  'stage2.0.weight':'module.layer4.0.weight',
                  'stage2.0.bias':'module.layer4.0.bias',
                  'stage2.1.weight':'module.layer4.1.weight',
                  'stage2.1.bias':'module.layer4.1.bias',
                  'stage2.1.running_mean':'module.layer4.1.running_mean',
                  'stage2.1.running_var':'module.layer4.1.running_var',
                  'stage2.3.weight':'module.layer5.0.weight',
                  'stage2.3.bias':'module.layer5.0.bias',
                  'stage2.4.weight':'module.layer5.1.weight',
                  'stage2.4.bias':'module.layer5.1.bias',
                  'stage2.4.running_mean':'module.layer5.1.running_mean',
                  'stage2.4.running_var':'module.layer5.1.running_var',
                  'stage3.0.weight':'module.layer7.0.weight',
                  'stage3.0.bias':'module.layer7.0.bias',
                  'stage3.1.weight':'module.layer7.1.weight',
                  'stage3.1.bias':'module.layer7.1.bias',
                  'stage3.1.running_mean':'module.layer7.1.running_mean',
                  'stage3.1.running_var':'module.layer7.1.running_var',
                  'stage3.3.weight':'module.layer8.0.weight',
                  'stage3.3.bias':'module.layer8.0.bias',
                  'stage3.4.weight':'module.layer8.1.weight',
                  'stage3.4.bias':'module.layer8.1.bias',
                  'stage3.4.running_mean':'module.layer8.1.running_mean',
                  'stage3.4.running_var':'module.layer8.1.running_var',
                  'stage4.0.weight':'module.layer10.0.weight',
                  'stage4.0.bias':'module.layer10.0.bias',
                  'stage4.1.weight':'module.layer10.1.weight',
                  'stage4.1.bias':'module.layer10.1.bias',
                  'stage4.1.running_mean':'module.layer10.1.running_mean',
                  'stage4.1.running_var':'module.layer10.1.running_var',
                  'stage4.3.weight':'module.layer11.0.weight',
                  'stage4.3.bias':'module.layer11.0.bias',
                  'stage4.4.weight':'module.layer11.1.weight',
                  'stage4.4.bias':'module.layer11.1.bias',
                  'stage4.4.running_mean':'module.layer11.1.running_mean',
                  'stage4.4.running_var':'module.layer11.1.running_var',
                  'stage5.0.weight':'module.layer13.0.weight',
                  'stage5.0.bias':'module.layer13.0.bias',
                  'stage5.1.weight':'module.layer13.1.weight',
                  'stage5.1.bias':'module.layer13.1.bias',
                  'stage5.1.running_mean':'module.layer13.1.running_mean',
                  'stage5.1.running_var':'module.layer13.1.running_var',
                  'stage5.3.weight':'module.layer14.0.weight',
                  'stage5.3.bias':'module.layer14.0.bias',
                  'stage5.4.weight':'module.layer14.1.weight',
                  'stage5.4.bias':'module.layer14.1.bias',
                  'stage5.4.running_mean':'module.layer14.1.running_mean',
                  'stage5.4.running_var':'module.layer14.1.running_var',
                  'stage6.0.weight':'module.layer16.0.weight',
                  'stage6.0.bias':'module.layer16.0.bias',
                  'stage6.1.weight':'module.layer16.1.weight',
                  'stage6.1.bias':'module.layer16.1.bias',
                  'stage6.1.running_mean':'module.layer16.1.running_mean',
                  'stage6.1.running_var':'module.layer16.1.running_var',
                  'stage7.0.weight':'module.layer18.0.weight',
                  'stage7.0.bias':'module.layer18.0.bias',
                  'stage7.1.weight':'module.layer18.1.weight',
                  'stage7.1.bias':'module.layer18.1.bias',
                  'stage7.1.running_mean':'module.layer18.1.running_mean',
                  'stage7.1.running_var':'module.layer18.1.running_var'}

net = Net(old_state_dict,state_dict_map)

for key,value in net.state_dict().items():
    
    if 'num_batches_tracked' in key:
        continue
    
    param1 = value
    param2 = old_state_dict[state_dict_map[key]]
    
    is_equal = torch.allclose(param1,param2)
    
    print(is_equal)

However, after I run this, I see the following in the console:

True
True
False
False
False
False
True
True
True
True
True
True
True
True
False
False
False
False
True
True
True
True
True
True
True
True
False
False
False
False
True
True
True
True
True
True
True
True
False
False
False
False
True
True
True
True
True
True
True
True
False
False
False
False
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True

This means that some parameter values in net.state_dict() are not correct. For example, the first False above corresponds to the 'stage1.1.weight' key. Comparing these two parameters:

>>> net.state_dict()['stage1.1.weight']
tensor([ 1.3956, -0.4016, -0.4760,  0.6024, -0.1390, -0.5199, -0.4298,
-0.9330, -0.3256,  0.9304, -0.2840,  0.8464,  0.0186, -1.6756, -1.9437,  0.0987])
>>> old_state_dict[state_dict_map['stage1.1.weight']]
tensor([-1.9006,  0.2286,  0.0249, -0.3460,  0.2868, -0.7308,  0.1748,
-1.0939, -1.6022,  1.3529,  1.2888,  0.0523, -1.5469,  0.7567,  0.7755,  2.0265])

Again, I am not sure where these numbers came from. In this case, net.state_dict()['stage1.1.weight'] should take the value of old_state_dict[state_dict_map['stage1.1.weight']].
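Printing the failing key names instead of bare True/False values makes this kind of check easier to read. This is a minimal sketch, assuming key_map maps new keys to old keys the same way state_dict_map does above; mismatched_keys and the small two-layer model are hypothetical:

```python
import torch
from torch import nn


def mismatched_keys(model, old_state_dict, key_map):
    """List state_dict keys whose values differ from their mapped source.

    key_map maps new (model) keys to old (checkpoint) keys.
    """
    bad = []
    for key, value in model.state_dict().items():
        if key.endswith('num_batches_tracked'):
            continue  # counters are not part of the mapping here
        old_key = key_map.get(key)
        if old_key is None or not torch.allclose(value, old_state_dict[old_key]):
            bad.append(key)  # unmapped or loaded with the wrong values
    return bad


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(2, 3), nn.BatchNorm1d(3))
key_map = {'0.weight': 'fc.weight', '0.bias': 'fc.bias',
           '1.weight': 'bn.weight', '1.bias': 'bn.bias',
           '1.running_mean': 'bn.running_mean', '1.running_var': 'bn.running_var'}
# Random placeholder checkpoint with matching shapes.
old_state_dict = {old: torch.randn_like(model.state_dict()[new])
                  for new, old in key_map.items()}

state_dict = model.state_dict()
for new_key, old_key in key_map.items():
    state_dict[new_key] = old_state_dict[old_key]
model.load_state_dict(state_dict)
print(mismatched_keys(model, old_state_dict, key_map))  # prints []
```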

I also want to add that while debugging here:

for new_key,old_key in state_dict_map.items():
    state_dict[new_key] = old_state_dict[old_key]

In the case where new_key == 'stage1.1.weight' and old_key == 'module.layer1.1.weight', state_dict[new_key] was indeed equal to old_state_dict[old_key], which suggests that the problem lies in self.load_state_dict(state_dict).
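Since the copy into state_dict looks correct, one way to narrow this down is to check whether any two keys of the model's own state_dict share memory, which would mean that loading one key also changes the other. This is a minimal diagnostic sketch; aliased_keys is a hypothetical helper that groups entries by Tensor.data_ptr(), and intentionally tied weights would be reported as well:

```python
from collections import defaultdict

import torch
from torch import nn


def aliased_keys(model):
    """Return groups of state_dict keys that point to the same memory."""
    groups = defaultdict(list)
    for key, tensor in model.state_dict().items():
        if tensor.numel() == 0:
            continue  # empty tensors may all report data_ptr() == 0
        groups[tensor.data_ptr()].append(key)
    return [keys for keys in groups.values() if len(keys) > 1]


# Same pattern as the network above: one BatchNorm2d used at two positions.
bn = nn.BatchNorm2d(4)
model = nn.Sequential(nn.Conv2d(1, 4, 3), bn, nn.ReLU(), nn.Conv2d(4, 4, 3), bn)
groups = aliased_keys(model)
print(groups[0])    # ['1.weight', '4.weight']
print(len(groups))  # 5 (weight, bias, running_mean, running_var, num_batches_tracked)
```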

I would greatly appreciate your feedback on this. Thanks a lot.

Hi @ptrblck,

Did you get a chance to look at the example above? Please let me know if anything is not clear.

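@Mahmoud_Abdelkhalek Just for fun, I am exploring this problem myself.
Got it. The problem is that you are reusing the same batch norm layer twice in each stage. Because one module is registered at two positions (e.g. stage1.1 and stage1.4), state_dict() lists it under both keys, but both keys refer to the same tensors. When load_state_dict() copies the batch norm values, the values for the later position overwrite the earlier ones (because they are the same module!). This is why the first batch norm entries of stages 1 to 5 fail your check, while stage6 and stage7, which contain only one batch norm layer each, pass.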
The problem is here:

stages = [nn.Sequential(conv1,
                        batch_norm,
                        activation,
                        conv2,
                        batch_norm,
                        activation,
                        pooling)]

Use different batch norm layers and your problem will be solved.
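A minimal reproduction of this effect, independent of the full network: the same BatchNorm2d instance is placed at indices 1 and 4 of an nn.Sequential, two different weights are loaded for those keys, and the module keeps only the one loaded last. With a separate instance per position, each weight survives. The build and load_two_weights helpers are hypothetical names for this sketch:

```python
import torch
from torch import nn


def build(share_bn):
    """Two conv/BN blocks; optionally reuse one BatchNorm2d instance for both."""
    bn1 = nn.BatchNorm2d(4)
    bn2 = bn1 if share_bn else nn.BatchNorm2d(4)
    return nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), bn1, nn.ReLU(),
                         nn.Conv2d(4, 4, 3, padding=1), bn2, nn.ReLU())


def load_two_weights(model):
    """Load 1.0 into '1.weight' and 2.0 into '4.weight', then read both back."""
    state_dict = model.state_dict()  # a shared module appears under both prefixes
    state_dict['1.weight'] = torch.full((4,), 1.0)
    state_dict['4.weight'] = torch.full((4,), 2.0)
    model.load_state_dict(state_dict)  # children are loaded in index order
    return model[1].weight.detach().clone(), model[4].weight.detach().clone()


w1, w4 = load_two_weights(build(share_bn=True))
print(w1, w4)  # tensor([2., 2., 2., 2.]) tensor([2., 2., 2., 2.])

w1, w4 = load_two_weights(build(share_bn=False))
print(w1, w4)  # tensor([1., 1., 1., 1.]) tensor([2., 2., 2., 2.])
```

Creating a fresh BatchNorm2d for each position, as the full solution below does, is enough to make every key map to its own tensors.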


Hi @Valerio_Biscione, could you possibly post your full working solution?

Here you go

import torch
from torch import nn
import csv
import copy


class Net(nn.Module):

    def __init__(self, old_state_dict, state_dict_map):

        # run nn.Module's constructor

        super(Net, self).__init__()

        # --------------------------------------------------------------------
        # build net

        in_channels = 16

        conv1 = nn.Conv2d(in_channels=1,
                          out_channels=in_channels,
                          kernel_size=(3, 3),
                          padding=(1, 1))

        batch_norm = nn.BatchNorm2d(num_features=16)
        batch_norm2 = nn.BatchNorm2d(num_features=16)

        activation = nn.ReLU()

        conv2 = nn.Conv2d(in_channels=in_channels,
                          out_channels=in_channels,
                          kernel_size=(3, 3),
                          padding=(1, 1))

        pooling = nn.MaxPool2d(kernel_size=(2, 2))

        # first stage

        stages = [nn.Sequential(conv1,
                                batch_norm,
                                activation,
                                conv2,
                                batch_norm2,
                                activation,
                                pooling)]

        # next 4 stages

        for i in range(4):
            conv1 = nn.Conv2d(in_channels=in_channels,
                              out_channels=in_channels * 2,
                              kernel_size=(3, 3),
                              padding=(1, 1))

            batch_norm = nn.BatchNorm2d(num_features=in_channels * 2)
            batch_norm2 = nn.BatchNorm2d(num_features=in_channels * 2)

            conv2 = nn.Conv2d(in_channels=in_channels * 2,
                              out_channels=in_channels * 2,
                              kernel_size=(3, 3),
                              padding=(1, 1))

            stages += [nn.Sequential(conv1,
                                     batch_norm,
                                     activation,
                                     conv2,
                                     batch_norm2,
                                     activation,
                                     pooling)]

            in_channels = in_channels * 2

        # 6th stage, in_channels = 256

        conv1 = nn.Conv2d(in_channels=in_channels,
                          out_channels=in_channels * 2,
                          kernel_size=(3, 3),
                          padding=(1, 1))

        batch_norm = nn.BatchNorm2d(num_features=in_channels * 2)

        stages += [nn.Sequential(conv1,
                                 batch_norm,
                                 activation,
                                 pooling)]

        in_channels = in_channels * 2

        # final stage, in_channels = 512

        conv1 = nn.Conv2d(in_channels=in_channels,
                          out_channels=in_channels * 2,
                          kernel_size=(2, 2),
                          padding=(0, 0))

        batch_norm = nn.BatchNorm2d(num_features=in_channels * 2)

        stages += [nn.Sequential(conv1,
                                 batch_norm,
                                 activation)]

        # assign names to the stages for the state_dict

        self.stage1 = stages[0]
        self.stage2 = stages[1]
        self.stage3 = stages[2]
        self.stage4 = stages[3]
        self.stage5 = stages[4]
        self.stage6 = stages[5]
        self.stage7 = stages[6]

        # --------------------------------------------------------------------

        # load the pre-trained parameters

        # make a copy to use load_state_dict() method later

        state_dict = copy.deepcopy(self.state_dict())

        for new_key, old_key in state_dict_map.items():
            state_dict[new_key] = old_state_dict[old_key]

        self.load_state_dict(state_dict)

    def forward(self, x):

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)

        # average of all pixels in each feature map

        x = nn.functional.avg_pool2d(input=x,
                                     kernel_size=x.shape[2:])

        # flatten from N x 1024 x 1 x 1 to N x 1024

        x = torch.flatten(input=x,
                          start_dim=1)

        return x


torch.manual_seed(42)

old_state_dict = {'module.layer1.0.weight': torch.randn((16, 1, 3, 3)),
                  'module.layer1.0.bias': torch.randn((16)),
                  'module.layer1.1.weight': torch.randn((16)),
                  'module.layer1.1.bias': torch.randn((16)),
                  'module.layer1.1.running_mean': torch.randn((16)),
                  'module.layer1.1.running_var': torch.randn((16)),
                  'module.layer2.0.weight': torch.randn((16, 16, 3, 3)),
                  'module.layer2.0.bias': torch.randn((16)),
                  'module.layer2.1.weight': torch.randn((16)),
                  'module.layer2.1.bias': torch.randn((16)),
                  'module.layer2.1.running_mean': torch.randn((16)),
                  'module.layer2.1.running_var': torch.randn((16)),
                  'module.layer4.0.weight': torch.randn((32, 16, 3, 3)),
                  'module.layer4.0.bias': torch.randn((32)),
                  'module.layer4.1.weight': torch.randn((32)),
                  'module.layer4.1.bias': torch.randn((32)),
                  'module.layer4.1.running_mean': torch.randn((32)),
                  'module.layer4.1.running_var': torch.randn((32)),
                  'module.layer5.0.weight': torch.randn((32, 32, 3, 3)),
                  'module.layer5.0.bias': torch.randn((32)),
                  'module.layer5.1.weight': torch.randn((32)),
                  'module.layer5.1.bias': torch.randn((32)),
                  'module.layer5.1.running_mean': torch.randn((32)),
                  'module.layer5.1.running_var': torch.randn((32)),
                  'module.layer7.0.weight': torch.randn((64, 32, 3, 3)),
                  'module.layer7.0.bias': torch.randn((64)),
                  'module.layer7.1.weight': torch.randn((64)),
                  'module.layer7.1.bias': torch.randn((64)),
                  'module.layer7.1.running_mean': torch.randn((64)),
                  'module.layer7.1.running_var': torch.randn((64)),
                  'module.layer8.0.weight': torch.randn((64, 64, 3, 3)),
                  'module.layer8.0.bias': torch.randn((64)),
                  'module.layer8.1.weight': torch.randn((64)),
                  'module.layer8.1.bias': torch.randn((64)),
                  'module.layer8.1.running_mean': torch.randn((64)),
                  'module.layer8.1.running_var': torch.randn((64)),
                  'module.layer10.0.weight': torch.randn((128, 64, 3, 3)),
                  'module.layer10.0.bias': torch.randn((128)),
                  'module.layer10.1.weight': torch.randn((128)),
                  'module.layer10.1.bias': torch.randn((128)),
                  'module.layer10.1.running_mean': torch.randn((128)),
                  'module.layer10.1.running_var': torch.randn((128)),
                  'module.layer11.0.weight': torch.randn((128, 128, 3, 3)),
                  'module.layer11.0.bias': torch.randn((128)),
                  'module.layer11.1.weight': torch.randn((128)),
                  'module.layer11.1.bias': torch.randn((128)),
                  'module.layer11.1.running_mean': torch.randn((128)),
                  'module.layer11.1.running_var': torch.randn((128)),
                  'module.layer13.0.weight': torch.randn((256, 128, 3, 3)),
                  'module.layer13.0.bias': torch.randn((256)),
                  'module.layer13.1.weight': torch.randn((256)),
                  'module.layer13.1.bias': torch.randn((256)),
                  'module.layer13.1.running_mean': torch.randn((256)),
                  'module.layer13.1.running_var': torch.randn((256)),
                  'module.layer14.0.weight': torch.randn((256, 256, 3, 3)),
                  'module.layer14.0.bias': torch.randn((256)),
                  'module.layer14.1.weight': torch.randn((256)),
                  'module.layer14.1.bias': torch.randn((256)),
                  'module.layer14.1.running_mean': torch.randn((256)),
                  'module.layer14.1.running_var': torch.randn((256)),
                  'module.layer16.0.weight': torch.randn((512, 256, 3, 3)),
                  'module.layer16.0.bias': torch.randn((512)),
                  'module.layer16.1.weight': torch.randn((512)),
                  'module.layer16.1.bias': torch.randn((512)),
                  'module.layer16.1.running_mean': torch.randn((512)),
                  'module.layer16.1.running_var': torch.randn((512)),
                  'module.layer18.0.weight': torch.randn((1024, 512, 2, 2)),
                  'module.layer18.0.bias': torch.randn((1024)),
                  'module.layer18.1.weight': torch.randn((1024)),
                  'module.layer18.1.bias': torch.randn((1024)),
                  'module.layer18.1.running_mean': torch.randn((1024)),
                  'module.layer18.1.running_var': torch.randn((1024))}

# Maps each key of the new Net (left) to the old key it is loaded from (right)
state_dict_map = {'stage1.0.weight': 'module.layer1.0.weight',
                  'stage1.0.bias': 'module.layer1.0.bias',
                  'stage1.1.weight': 'module.layer1.1.weight',
                  'stage1.1.bias': 'module.layer1.1.bias',
                  'stage1.1.running_mean': 'module.layer1.1.running_mean',
                  'stage1.1.running_var': 'module.layer1.1.running_var',
                  'stage1.3.weight': 'module.layer2.0.weight',
                  'stage1.3.bias': 'module.layer2.0.bias',
                  'stage1.4.weight': 'module.layer2.1.weight',
                  'stage1.4.bias': 'module.layer2.1.bias',
                  'stage1.4.running_mean': 'module.layer2.1.running_mean',
                  'stage1.4.running_var': 'module.layer2.1.running_var',
                  'stage2.0.weight': 'module.layer4.0.weight',
                  'stage2.0.bias': 'module.layer4.0.bias',
                  'stage2.1.weight': 'module.layer4.1.weight',
                  'stage2.1.bias': 'module.layer4.1.bias',
                  'stage2.1.running_mean': 'module.layer4.1.running_mean',
                  'stage2.1.running_var': 'module.layer4.1.running_var',
                  'stage2.3.weight': 'module.layer5.0.weight',
                  'stage2.3.bias': 'module.layer5.0.bias',
                  'stage2.4.weight': 'module.layer5.1.weight',
                  'stage2.4.bias': 'module.layer5.1.bias',
                  'stage2.4.running_mean': 'module.layer5.1.running_mean',
                  'stage2.4.running_var': 'module.layer5.1.running_var',
                  'stage3.0.weight': 'module.layer7.0.weight',
                  'stage3.0.bias': 'module.layer7.0.bias',
                  'stage3.1.weight': 'module.layer7.1.weight',
                  'stage3.1.bias': 'module.layer7.1.bias',
                  'stage3.1.running_mean': 'module.layer7.1.running_mean',
                  'stage3.1.running_var': 'module.layer7.1.running_var',
                  'stage3.3.weight': 'module.layer8.0.weight',
                  'stage3.3.bias': 'module.layer8.0.bias',
                  'stage3.4.weight': 'module.layer8.1.weight',
                  'stage3.4.bias': 'module.layer8.1.bias',
                  'stage3.4.running_mean': 'module.layer8.1.running_mean',
                  'stage3.4.running_var': 'module.layer8.1.running_var',
                  'stage4.0.weight': 'module.layer10.0.weight',
                  'stage4.0.bias': 'module.layer10.0.bias',
                  'stage4.1.weight': 'module.layer10.1.weight',
                  'stage4.1.bias': 'module.layer10.1.bias',
                  'stage4.1.running_mean': 'module.layer10.1.running_mean',
                  'stage4.1.running_var': 'module.layer10.1.running_var',
                  'stage4.3.weight': 'module.layer11.0.weight',
                  'stage4.3.bias': 'module.layer11.0.bias',
                  'stage4.4.weight': 'module.layer11.1.weight',
                  'stage4.4.bias': 'module.layer11.1.bias',
                  'stage4.4.running_mean': 'module.layer11.1.running_mean',
                  'stage4.4.running_var': 'module.layer11.1.running_var',
                  'stage5.0.weight': 'module.layer13.0.weight',
                  'stage5.0.bias': 'module.layer13.0.bias',
                  'stage5.1.weight': 'module.layer13.1.weight',
                  'stage5.1.bias': 'module.layer13.1.bias',
                  'stage5.1.running_mean': 'module.layer13.1.running_mean',
                  'stage5.1.running_var': 'module.layer13.1.running_var',
                  'stage5.3.weight': 'module.layer14.0.weight',
                  'stage5.3.bias': 'module.layer14.0.bias',
                  'stage5.4.weight': 'module.layer14.1.weight',
                  'stage5.4.bias': 'module.layer14.1.bias',
                  'stage5.4.running_mean': 'module.layer14.1.running_mean',
                  'stage5.4.running_var': 'module.layer14.1.running_var',
                  'stage6.0.weight': 'module.layer16.0.weight',
                  'stage6.0.bias': 'module.layer16.0.bias',
                  'stage6.1.weight': 'module.layer16.1.weight',
                  'stage6.1.bias': 'module.layer16.1.bias',
                  'stage6.1.running_mean': 'module.layer16.1.running_mean',
                  'stage6.1.running_var': 'module.layer16.1.running_var',
                  'stage7.0.weight': 'module.layer18.0.weight',
                  'stage7.0.bias': 'module.layer18.0.bias',
                  'stage7.1.weight': 'module.layer18.1.weight',
                  'stage7.1.bias': 'module.layer18.1.bias',
                  'stage7.1.running_mean': 'module.layer18.1.running_mean',
                  'stage7.1.running_var': 'module.layer18.1.running_var'}

net = Net(old_state_dict, state_dict_map)

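# Every parameter/buffer of the new Net should now equal its source tensor
# (num_batches_tracked is skipped because old_state_dict does not contain it)
for key, value in net.state_dict().items():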

    if 'num_batches_tracked' in key:
        continue

    param1 = value
    param2 = old_state_dict[state_dict_map[key]]

    is_equal = torch.allclose(param1, param2)

    print(is_equal)

Please accept it as a solution if it solved your problem :slight_smile:

Hi @Valerio_Biscione, thanks a lot for your solution!

However, I am not sure why this is a problem in the first place, since in:

stages = [nn.Sequential(conv1,
                        batch_norm,
                        activation,
                        conv2,
                        batch_norm,
                        activation,
                        pooling)]

In the state_dict, the first batch_norm's weight, bias, running mean, and running variance have different names (stage1.1.*) from those of the second batch_norm (stage1.4.*). So PyTorch should see these as two different modules, right?

The names are different, but load_state_dict only uses them to find the module to copy into: when it reaches stage1.1, it copies those values into the batch_norm module. When it then reaches stage1.4, it copies the stage1.4 values into the same batch_norm module (in your Sequential, these two entries are in fact the same object), overwriting what was just loaded for stage1.1. That is why stage1.1.weight ended up with different values from the ones you put in the state_dict. I hope that makes sense :slight_smile:
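A minimal sketch that reproduces this, assuming a toy nn.Sequential (the layer sizes and the fill values 1.0/2.0 are made up for illustration) in which a single nn.BatchNorm1d instance sits at indices 1 and 4. Both keys show up in the state_dict, but they refer to the same storage, so whichever entry is loaded last wins:

```python
import torch
import torch.nn as nn

# Toy model: the SAME BatchNorm1d object is used twice (indices 1 and 4).
bn = nn.BatchNorm1d(4)
net = nn.Sequential(nn.Linear(4, 4), bn, nn.ReLU(),
                    nn.Linear(4, 4), bn, nn.ReLU())

sd = net.state_dict()
# Both names are listed...
print('1.weight' in sd, '4.weight' in sd)                    # True True
# ...but they are views of one and the same tensor.
print(sd['1.weight'].data_ptr() == sd['4.weight'].data_ptr())  # True

# Give the two names different values, as the CSV mapping did.
new_sd = dict(sd)
new_sd['1.weight'] = torch.full((4,), 1.0)
new_sd['4.weight'] = torch.full((4,), 2.0)
net.load_state_dict(new_sd)

# Index 4 is loaded after index 1, so its values overwrite the first copy.
print(torch.equal(net[1].weight.detach(), torch.full((4,), 2.0)))  # True
```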
This problem is actually pretty common, not so much when loading a state_dict but when training: reusing the same module in different parts of the network makes those parts share their weights (and, for batch norm, their running statistics), which can lead to weird results! So watch out for it!
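A minimal sketch of one way to avoid it, assuming two hypothetical helpers, make_stage and find_aliased_keys (neither is a PyTorch API; the channel counts loosely mirror stage1 but are otherwise illustrative). make_stage constructs new modules on every call, and find_aliased_keys groups state_dict entries that share storage, which flags accidental reuse before you load weights or train:

```python
import collections

import torch.nn as nn


def make_stage(in_ch, out_ch):
    # Hypothetical helper: every call creates NEW conv/batch-norm modules,
    # so positions 1 and 4 get independent parameters and running stats.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )


def find_aliased_keys(model):
    # Hypothetical helper: group state_dict keys whose tensors start at the
    # same memory address, i.e. entries that alias one underlying tensor.
    groups = collections.defaultdict(list)
    for key, tensor in model.state_dict().items():
        if tensor.numel() > 0:  # empty tensors can all report data_ptr() == 0
            groups[tensor.data_ptr()].append(key)
    return [keys for keys in groups.values() if len(keys) > 1]


# The problematic pattern: one BatchNorm2d object reused at indices 1 and 4.
bn = nn.BatchNorm2d(16)
shared = nn.Sequential(nn.Conv2d(1, 16, 3, padding=1), bn, nn.ReLU(),
                       nn.Conv2d(16, 16, 3, padding=1), bn, nn.ReLU(),
                       nn.MaxPool2d(2))
print(find_aliased_keys(shared)[0])    # ['1.weight', '4.weight']
print(len(find_aliased_keys(shared)))  # 5 (weight, bias, mean, var, counter)

fixed = make_stage(1, 16)
print(find_aliased_keys(fixed))        # []
```

Calling the constructor once per position is what gives each layer its own parameters; passing the same instance twice, as in `shared`, makes those positions share them. If weight sharing is intended somewhere, the check will list those groups too, so treat its output as a hint rather than an error.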


Hi @Valerio_Biscione, thanks a lot for the clarification! It makes sense now.
