Copy weights inside the model

Hello,
I have a multi-branch architecture with 3 branches at the end.
I would like to warm-start training by loading a pretrained state dictionary that only has 1 branch at the end.
So after loading the state dictionary, I would like to copy the weights of the first branch to the second and third branches.

This is what I have so far:

    import torch

    # `model` is the 3-branch network, constructed beforehand
    pretrained_dict = torch.load('H:/workspace/pretrain/epoch_500_pretrain_resnet50.pth')
    model_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
    # copy branch 0 -> branches 1/2
    #model.branches[1].weight.data = model.branches[0].weight.data ??

The problem is the last line of code: I do not know how to address the model's weight data.
I looked through the forums, but I can't find a correct way to do this.

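Don’t use the .data attribute, as it might yield some (hidden) side effects: changes made through .data are not tracked by autograd, so a resulting wrong gradient would not raise an error.
I would recommend using: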

with torch.no_grad():
    model.branches[1].weight.copy_(model.branches[0].weight)

Following your advice, I tried to copy with .weight and .bias, but I fail to get the expected results.
After loading the state dict of a model that only has 1 branch (called branch 0), branch 0 achieves the result it should if I disable the other branches during forward propagation. So that branch is confirmed to have correct weights and biases.
But if I also let the images go through branch 1 and branch 2 (which should have the same weights and biases), the performance drops greatly, meaning I failed to copy the weights and biases, or there is something else I'm missing.

I used the following code to make sure I take every parameter into account:

        for name, child in model.named_children():
            for name2, params in child.named_parameters():
                print(name, name2)
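One thing worth checking here: `named_parameters()` only lists trainable parameters. Batch norm running statistics are registered as buffers and are listed by `named_buffers()` instead. A small sketch, using a toy conv + batch norm block as a hypothetical stand-in for one piece of a branch:

```python
import torch.nn as nn

# Toy block standing in for a piece of one branch (hypothetical).
block = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, bias=False),
    nn.BatchNorm2d(8),
)

param_names = [name for name, _ in block.named_parameters()]
buffer_names = [name for name, _ in block.named_buffers()]

print(param_names)   # ['0.weight', '1.weight', '1.bias']
print(buffer_names)  # ['1.running_mean', '1.running_var', '1.num_batches_tracked']
```

The `state_dict()` of a module contains both groups, which is why loading a state dict restores more than a parameter-by-parameter copy does.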

Then I wrote a pretty long manual copy code
(please ignore the incorrect comment names of the resnet parts):

        # copy branch 0 -> branches 1/2
        with torch.no_grad():
            # copy to branch 1
            ## block 0
            ### copy bottleneck 0
            model.branches[1][0][0].conv1.weight.        copy_(model.branches[0][0][0].conv1.weight)
            model.branches[1][0][0].bn1.weight.          copy_(model.branches[0][0][0].bn1.weight)
            model.branches[1][0][0].bn1.bias.            copy_(model.branches[0][0][0].bn1.bias)
            model.branches[1][0][0].conv2.weight.        copy_(model.branches[0][0][0].conv2.weight)
            model.branches[1][0][0].bn2.weight.          copy_(model.branches[0][0][0].bn2.weight)
            model.branches[1][0][0].bn2.bias.            copy_(model.branches[0][0][0].bn2.bias)
            model.branches[1][0][0].conv3.weight.        copy_(model.branches[0][0][0].conv3.weight)
            model.branches[1][0][0].bn3.weight.          copy_(model.branches[0][0][0].bn3.weight)
            model.branches[1][0][0].bn3.bias.            copy_(model.branches[0][0][0].bn3.bias)
            model.branches[1][0][0].downsample[0].weight.copy_(model.branches[0][0][0].downsample[0].weight)
            model.branches[1][0][0].downsample[1].weight.copy_(model.branches[0][0][0].downsample[1].weight)
            model.branches[1][0][0].downsample[1].bias.  copy_(model.branches[0][0][0].downsample[1].bias)
            ### bottleneck 1
            model.branches[1][0][1].conv1.weight.copy_(model.branches[0][0][1].conv1.weight)
            model.branches[1][0][1].bn1.weight.  copy_(model.branches[0][0][1].bn1.weight)
            model.branches[1][0][1].bn1.bias.    copy_(model.branches[0][0][1].bn1.bias)
            model.branches[1][0][1].conv2.weight.copy_(model.branches[0][0][1].conv2.weight)
            model.branches[1][0][1].bn2.weight.  copy_(model.branches[0][0][1].bn2.weight)
            model.branches[1][0][1].bn2.bias.    copy_(model.branches[0][0][1].bn2.bias)
            model.branches[1][0][1].conv3.weight.copy_(model.branches[0][0][1].conv3.weight)
            model.branches[1][0][1].bn3.weight.  copy_(model.branches[0][0][1].bn3.weight)
            model.branches[1][0][1].bn3.bias.    copy_(model.branches[0][0][1].bn3.bias)
            ### bottleneck 2
            model.branches[1][0][2].conv1.weight.copy_(model.branches[0][0][2].conv1.weight)
            model.branches[1][0][2].bn1.weight.  copy_(model.branches[0][0][2].bn1.weight)
            model.branches[1][0][2].bn1.bias.    copy_(model.branches[0][0][2].bn1.bias)
            model.branches[1][0][2].conv2.weight.copy_(model.branches[0][0][2].conv2.weight)
            model.branches[1][0][2].bn2.weight.  copy_(model.branches[0][0][2].bn2.weight)
            model.branches[1][0][2].bn2.bias.    copy_(model.branches[0][0][2].bn2.bias)
            model.branches[1][0][2].conv3.weight.copy_(model.branches[0][0][2].conv3.weight)
            model.branches[1][0][2].bn3.weight.  copy_(model.branches[0][0][2].bn3.weight)
            model.branches[1][0][2].bn3.bias.    copy_(model.branches[0][0][2].bn3.bias)
            ### bottleneck 3
            model.branches[1][0][3].conv1.weight.copy_(model.branches[0][0][3].conv1.weight)
            model.branches[1][0][3].bn1.weight.  copy_(model.branches[0][0][3].bn1.weight)
            model.branches[1][0][3].bn1.bias.    copy_(model.branches[0][0][3].bn1.bias)
            model.branches[1][0][3].conv2.weight.copy_(model.branches[0][0][3].conv2.weight)
            model.branches[1][0][3].bn2.weight.  copy_(model.branches[0][0][3].bn2.weight)
            model.branches[1][0][3].bn2.bias.    copy_(model.branches[0][0][3].bn2.bias)
            model.branches[1][0][3].conv3.weight.copy_(model.branches[0][0][3].conv3.weight)
            model.branches[1][0][3].bn3.weight.  copy_(model.branches[0][0][3].bn3.weight)
            model.branches[1][0][3].bn3.bias.    copy_(model.branches[0][0][3].bn3.bias)
            ### bottleneck 4
            model.branches[1][0][4].conv1.weight.copy_(model.branches[0][0][4].conv1.weight)
            model.branches[1][0][4].bn1.weight.  copy_(model.branches[0][0][4].bn1.weight)
            model.branches[1][0][4].bn1.bias.    copy_(model.branches[0][0][4].bn1.bias)
            model.branches[1][0][4].conv2.weight.copy_(model.branches[0][0][4].conv2.weight)
            model.branches[1][0][4].bn2.weight.  copy_(model.branches[0][0][4].bn2.weight)
            model.branches[1][0][4].bn2.bias.    copy_(model.branches[0][0][4].bn2.bias)
            model.branches[1][0][4].conv3.weight.copy_(model.branches[0][0][4].conv3.weight)
            model.branches[1][0][4].bn3.weight.  copy_(model.branches[0][0][4].bn3.weight)
            model.branches[1][0][4].bn3.bias.    copy_(model.branches[0][0][4].bn3.bias)
            ### bottleneck 5
            model.branches[1][0][5].conv1.weight.copy_(model.branches[0][0][5].conv1.weight)
            model.branches[1][0][5].bn1.weight.  copy_(model.branches[0][0][5].bn1.weight)
            model.branches[1][0][5].bn1.bias.    copy_(model.branches[0][0][5].bn1.bias)
            model.branches[1][0][5].conv2.weight.copy_(model.branches[0][0][5].conv2.weight)
            model.branches[1][0][5].bn2.weight.  copy_(model.branches[0][0][5].bn2.weight)
            model.branches[1][0][5].bn2.bias.    copy_(model.branches[0][0][5].bn2.bias)
            model.branches[1][0][5].conv3.weight.copy_(model.branches[0][0][5].conv3.weight)
            model.branches[1][0][5].bn3.weight.  copy_(model.branches[0][0][5].bn3.weight)
            model.branches[1][0][5].bn3.bias.    copy_(model.branches[0][0][5].bn3.bias)

            ## block1
            ### bottleneck 0
            model.branches[1][1][0].conv1.weight.        copy_(model.branches[0][1][0].conv1.weight)
            model.branches[1][1][0].bn1.weight.          copy_(model.branches[0][1][0].bn1.weight)
            model.branches[1][1][0].bn1.bias.            copy_(model.branches[0][1][0].bn1.bias)
            model.branches[1][1][0].conv2.weight.        copy_(model.branches[0][1][0].conv2.weight)
            model.branches[1][1][0].bn2.weight.          copy_(model.branches[0][1][0].bn2.weight)
            model.branches[1][1][0].bn2.bias.            copy_(model.branches[0][1][0].bn2.bias)
            model.branches[1][1][0].conv3.weight.        copy_(model.branches[0][1][0].conv3.weight)
            model.branches[1][1][0].bn3.weight.          copy_(model.branches[0][1][0].bn3.weight)
            model.branches[1][1][0].bn3.bias.            copy_(model.branches[0][1][0].bn3.bias)
            model.branches[1][1][0].downsample[0].weight.copy_(model.branches[0][1][0].downsample[0].weight)
            model.branches[1][1][0].downsample[1].weight.copy_(model.branches[0][1][0].downsample[1].weight)
            model.branches[1][1][0].downsample[1].bias.  copy_(model.branches[0][1][0].downsample[1].bias)
            ### bottleneck 1
            model.branches[1][1][1].conv1.weight.copy_(model.branches[0][1][1].conv1.weight)
            model.branches[1][1][1].bn1.weight.  copy_(model.branches[0][1][1].bn1.weight)
            model.branches[1][1][1].bn1.bias.    copy_(model.branches[0][1][1].bn1.bias)
            model.branches[1][1][1].conv2.weight.copy_(model.branches[0][1][1].conv2.weight)
            model.branches[1][1][1].bn2.weight.  copy_(model.branches[0][1][1].bn2.weight)
            model.branches[1][1][1].bn2.bias.    copy_(model.branches[0][1][1].bn2.bias)
            model.branches[1][1][1].conv3.weight.copy_(model.branches[0][1][1].conv3.weight)
            model.branches[1][1][1].bn3.weight.  copy_(model.branches[0][1][1].bn3.weight)
            model.branches[1][1][1].bn3.bias.    copy_(model.branches[0][1][1].bn3.bias)
            ### bottleneck 2
            model.branches[1][1][2].conv1.weight.copy_(model.branches[0][1][2].conv1.weight)
            model.branches[1][1][2].bn1.weight.  copy_(model.branches[0][1][2].bn1.weight)
            model.branches[1][1][2].bn1.bias.    copy_(model.branches[0][1][2].bn1.bias)
            model.branches[1][1][2].conv2.weight.copy_(model.branches[0][1][2].conv2.weight)
            model.branches[1][1][2].bn2.weight.  copy_(model.branches[0][1][2].bn2.weight)
            model.branches[1][1][2].bn2.bias.    copy_(model.branches[0][1][2].bn2.bias)
            model.branches[1][1][2].conv3.weight.copy_(model.branches[0][1][2].conv3.weight)
            model.branches[1][1][2].bn3.weight.  copy_(model.branches[0][1][2].bn3.weight)
            model.branches[1][1][2].bn3.bias.    copy_(model.branches[0][1][2].bn3.bias)

            # copy to branch 2
            ## block 0
            ### copy bottleneck 0
            model.branches[2][0][0].conv1.weight.        copy_(model.branches[0][0][0].conv1.weight)
            model.branches[2][0][0].bn1.weight.          copy_(model.branches[0][0][0].bn1.weight)
            model.branches[2][0][0].bn1.bias.            copy_(model.branches[0][0][0].bn1.bias)
            model.branches[2][0][0].conv2.weight.        copy_(model.branches[0][0][0].conv2.weight)
            model.branches[2][0][0].bn2.weight.          copy_(model.branches[0][0][0].bn2.weight)
            model.branches[2][0][0].bn2.bias.            copy_(model.branches[0][0][0].bn2.bias)
            model.branches[2][0][0].conv3.weight.        copy_(model.branches[0][0][0].conv3.weight)
            model.branches[2][0][0].bn3.weight.          copy_(model.branches[0][0][0].bn3.weight)
            model.branches[2][0][0].bn3.bias.            copy_(model.branches[0][0][0].bn3.bias)
            model.branches[2][0][0].downsample[0].weight.copy_(model.branches[0][0][0].downsample[0].weight)
            model.branches[2][0][0].downsample[1].weight.copy_(model.branches[0][0][0].downsample[1].weight)
            model.branches[2][0][0].downsample[1].bias.  copy_(model.branches[0][0][0].downsample[1].bias)
            ### bottleneck 1
            model.branches[2][0][1].conv1.weight.copy_(model.branches[0][0][1].conv1.weight)
            model.branches[2][0][1].bn1.weight.  copy_(model.branches[0][0][1].bn1.weight)
            model.branches[2][0][1].bn1.bias.    copy_(model.branches[0][0][1].bn1.bias)
            model.branches[2][0][1].conv2.weight.copy_(model.branches[0][0][1].conv2.weight)
            model.branches[2][0][1].bn2.weight.  copy_(model.branches[0][0][1].bn2.weight)
            model.branches[2][0][1].bn2.bias.    copy_(model.branches[0][0][1].bn2.bias)
            model.branches[2][0][1].conv3.weight.copy_(model.branches[0][0][1].conv3.weight)
            model.branches[2][0][1].bn3.weight.  copy_(model.branches[0][0][1].bn3.weight)
            model.branches[2][0][1].bn3.bias.    copy_(model.branches[0][0][1].bn3.bias)
            ### bottleneck 2
            model.branches[2][0][2].conv1.weight.copy_(model.branches[0][0][2].conv1.weight)
            model.branches[2][0][2].bn1.weight.  copy_(model.branches[0][0][2].bn1.weight)
            model.branches[2][0][2].bn1.bias.    copy_(model.branches[0][0][2].bn1.bias)
            model.branches[2][0][2].conv2.weight.copy_(model.branches[0][0][2].conv2.weight)
            model.branches[2][0][2].bn2.weight.  copy_(model.branches[0][0][2].bn2.weight)
            model.branches[2][0][2].bn2.bias.    copy_(model.branches[0][0][2].bn2.bias)
            model.branches[2][0][2].conv3.weight.copy_(model.branches[0][0][2].conv3.weight)
            model.branches[2][0][2].bn3.weight.  copy_(model.branches[0][0][2].bn3.weight)
            model.branches[2][0][2].bn3.bias.    copy_(model.branches[0][0][2].bn3.bias)
            ### bottleneck 3
            model.branches[2][0][3].conv1.weight.copy_(model.branches[0][0][3].conv1.weight)
            model.branches[2][0][3].bn1.weight.  copy_(model.branches[0][0][3].bn1.weight)
            model.branches[2][0][3].bn1.bias.    copy_(model.branches[0][0][3].bn1.bias)
            model.branches[2][0][3].conv2.weight.copy_(model.branches[0][0][3].conv2.weight)
            model.branches[2][0][3].bn2.weight.  copy_(model.branches[0][0][3].bn2.weight)
            model.branches[2][0][3].bn2.bias.    copy_(model.branches[0][0][3].bn2.bias)
            model.branches[2][0][3].conv3.weight.copy_(model.branches[0][0][3].conv3.weight)
            model.branches[2][0][3].bn3.weight.  copy_(model.branches[0][0][3].bn3.weight)
            model.branches[2][0][3].bn3.bias.    copy_(model.branches[0][0][3].bn3.bias)
            ### bottleneck 4
            model.branches[2][0][4].conv1.weight.copy_(model.branches[0][0][4].conv1.weight)
            model.branches[2][0][4].bn1.weight.  copy_(model.branches[0][0][4].bn1.weight)
            model.branches[2][0][4].bn1.bias.    copy_(model.branches[0][0][4].bn1.bias)
            model.branches[2][0][4].conv2.weight.copy_(model.branches[0][0][4].conv2.weight)
            model.branches[2][0][4].bn2.weight.  copy_(model.branches[0][0][4].bn2.weight)
            model.branches[2][0][4].bn2.bias.    copy_(model.branches[0][0][4].bn2.bias)
            model.branches[2][0][4].conv3.weight.copy_(model.branches[0][0][4].conv3.weight)
            model.branches[2][0][4].bn3.weight.  copy_(model.branches[0][0][4].bn3.weight)
            model.branches[2][0][4].bn3.bias.    copy_(model.branches[0][0][4].bn3.bias)
            ### bottleneck 5
            model.branches[2][0][5].conv1.weight.copy_(model.branches[0][0][5].conv1.weight)
            model.branches[2][0][5].bn1.weight.  copy_(model.branches[0][0][5].bn1.weight)
            model.branches[2][0][5].bn1.bias.    copy_(model.branches[0][0][5].bn1.bias)
            model.branches[2][0][5].conv2.weight.copy_(model.branches[0][0][5].conv2.weight)
            model.branches[2][0][5].bn2.weight.  copy_(model.branches[0][0][5].bn2.weight)
            model.branches[2][0][5].bn2.bias.    copy_(model.branches[0][0][5].bn2.bias)
            model.branches[2][0][5].conv3.weight.copy_(model.branches[0][0][5].conv3.weight)
            model.branches[2][0][5].bn3.weight.  copy_(model.branches[0][0][5].bn3.weight)
            model.branches[2][0][5].bn3.bias.    copy_(model.branches[0][0][5].bn3.bias)
            ## block 1
            ### bottleneck 0
            model.branches[2][1][0].conv1.weight.        copy_(model.branches[0][1][0].conv1.weight)
            model.branches[2][1][0].bn1.weight.          copy_(model.branches[0][1][0].bn1.weight)
            model.branches[2][1][0].bn1.bias.            copy_(model.branches[0][1][0].bn1.bias)
            model.branches[2][1][0].conv2.weight.        copy_(model.branches[0][1][0].conv2.weight)
            model.branches[2][1][0].bn2.weight.          copy_(model.branches[0][1][0].bn2.weight)
            model.branches[2][1][0].bn2.bias.            copy_(model.branches[0][1][0].bn2.bias)
            model.branches[2][1][0].conv3.weight.        copy_(model.branches[0][1][0].conv3.weight)
            model.branches[2][1][0].bn3.weight.          copy_(model.branches[0][1][0].bn3.weight)
            model.branches[2][1][0].bn3.bias.            copy_(model.branches[0][1][0].bn3.bias)
            model.branches[2][1][0].downsample[0].weight.copy_(model.branches[0][1][0].downsample[0].weight)
            model.branches[2][1][0].downsample[1].weight.copy_(model.branches[0][1][0].downsample[1].weight)
            model.branches[2][1][0].downsample[1].bias.  copy_(model.branches[0][1][0].downsample[1].bias)
            ### bottleneck 1
            model.branches[2][1][1].conv1.weight.copy_(model.branches[0][1][1].conv1.weight)
            model.branches[2][1][1].bn1.weight.  copy_(model.branches[0][1][1].bn1.weight)
            model.branches[2][1][1].bn1.bias.    copy_(model.branches[0][1][1].bn1.bias)
            model.branches[2][1][1].conv2.weight.copy_(model.branches[0][1][1].conv2.weight)
            model.branches[2][1][1].bn2.weight.  copy_(model.branches[0][1][1].bn2.weight)
            model.branches[2][1][1].bn2.bias.    copy_(model.branches[0][1][1].bn2.bias)
            model.branches[2][1][1].conv3.weight.copy_(model.branches[0][1][1].conv3.weight)
            model.branches[2][1][1].bn3.weight.  copy_(model.branches[0][1][1].bn3.weight)
            model.branches[2][1][1].bn3.bias.    copy_(model.branches[0][1][1].bn3.bias)
            ### bottleneck 2
            model.branches[2][1][2].conv1.weight.copy_(model.branches[0][1][2].conv1.weight)
            model.branches[2][1][2].bn1.weight.  copy_(model.branches[0][1][2].bn1.weight)
            model.branches[2][1][2].bn1.bias.    copy_(model.branches[0][1][2].bn1.bias)
            model.branches[2][1][2].conv2.weight.copy_(model.branches[0][1][2].conv2.weight)
            model.branches[2][1][2].bn2.weight.  copy_(model.branches[0][1][2].bn2.weight)
            model.branches[2][1][2].bn2.bias.    copy_(model.branches[0][1][2].bn2.bias)
            model.branches[2][1][2].conv3.weight.copy_(model.branches[0][1][2].conv3.weight)
            model.branches[2][1][2].bn3.weight.  copy_(model.branches[0][1][2].bn3.weight)
            model.branches[2][1][2].bn3.bias.    copy_(model.branches[0][1][2].bn3.bias)

            # copy fcs
            model.fcs[1].weight.copy_(model.fcs[0].weight) 
            model.fcs[1].bias.  copy_(model.fcs[0].bias)
            model.fcs[2].weight.copy_(model.fcs[0].weight)
            model.fcs[2].bias.  copy_(model.fcs[0].bias)

You might have to copy the running stats of each batch norm layer as well.
Otherwise the other branches will keep their initial statistics (zero mean, unit variance).
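A loop can cover both the parameters and the running stats without listing every layer by hand. This is only a sketch: `make_branch` and `copy_branch` are hypothetical names, the toy branch is far smaller than a real resnet branch, and the loop assumes source and target have exactly the same architecture.

```python
import torch
import torch.nn as nn

def make_branch():
    # Hypothetical stand-in for one branch of the real model.
    return nn.Sequential(nn.Conv2d(3, 8, 3, bias=False), nn.BatchNorm2d(8), nn.ReLU())

branches = nn.ModuleList([make_branch() for _ in range(3)])

# Give branch 0 non-default running statistics, as a trained branch would have.
branches[0].train()
with torch.no_grad():
    branches[0](torch.randn(16, 3, 10, 10))

def copy_branch(src, dst):
    """Copy every parameter and buffer of src into dst, in place.

    Assumes src and dst have identical architectures, so that
    parameters() and buffers() yield tensors in the same order.
    """
    with torch.no_grad():
        for p_src, p_dst in zip(src.parameters(), dst.parameters()):
            p_dst.copy_(p_src)
        for b_src, b_dst in zip(src.buffers(), dst.buffers()):
            b_dst.copy_(b_src)

for target in branches[1:]:
    copy_branch(branches[0], target)
```

The second loop is the part that a weight-and-bias-only copy leaves out: it transfers `running_mean`, `running_var` and `num_batches_tracked` of every batch norm layer.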

This sounds like using the .data attribute.
If I wish to train these other branches for a similar but slightly different classification, would you suggest copying the batch norm running stats?
Since you mentioned hidden side effects, I'm not sure if it is a good idea or not.

Don’t use the .data attribute.

If you are copying the trainable parameters to side branches, I would recommend copying everything (all parameters and buffers) to these branches. This might give you a “good initialization” for the training of these classifiers.

I’m not sure if this approach is better than training from scratch, but I can see the similarity to using a pretrained model for a new classification task.
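If the branches are ordinary submodules, one way to copy everything at once is to call `load_state_dict` on the target branch with the source branch's `state_dict()`, which covers parameters and buffers together. The sketch below assumes a toy `make_branch` helper (hypothetical) and then checks in eval mode that all branches give the same output for the same input:

```python
import torch
import torch.nn as nn

def make_branch():
    # Hypothetical stand-in for one branch of the real model.
    return nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

branches = nn.ModuleList([make_branch() for _ in range(3)])

# Simulate a trained branch 0 by updating its batch norm running stats.
branches[0].train()
with torch.no_grad():
    branches[0](torch.randn(16, 3, 10, 10))

# Copy parameters AND buffers from branch 0 into the other branches.
for target in branches[1:]:
    target.load_state_dict(branches[0].state_dict())

# Compare outputs in eval mode, so batch norm uses the running stats.
branches.eval()
x = torch.randn(2, 3, 10, 10)
with torch.no_grad():
    outputs = [branch(x) for branch in branches]
identical = all(torch.allclose(outputs[0], out) for out in outputs[1:])
print(identical)  # True
```

A check like this on the real model, feeding one tensor through all branches in eval mode, is a quick way to confirm that the copy is complete.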

I did copy the running_mean and running_var of the bn layers as well. The result is closer if I propagate the same tensor through all branches, but it is still incorrect.
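One way to narrow down what is still different is to compare the state dicts of two branches key by key. A sketch, with `diff_state` as a hypothetical helper and a toy block in place of a real branch; the demo deliberately copies only the parameters to show what the helper reports:

```python
import torch
import torch.nn as nn

def diff_state(src, dst):
    """Return the state_dict keys whose tensors differ between two modules."""
    src_state, dst_state = src.state_dict(), dst.state_dict()
    return [
        key for key in src_state
        if key not in dst_state or not torch.equal(src_state[key], dst_state[key])
    ]

# Toy source and target blocks (hypothetical stand-ins for two branches).
src = nn.Sequential(nn.Conv2d(3, 8, 3, bias=False), nn.BatchNorm2d(8))
dst = nn.Sequential(nn.Conv2d(3, 8, 3, bias=False), nn.BatchNorm2d(8))

src.train()
with torch.no_grad():
    src(torch.randn(16, 3, 10, 10))  # updates src's batch norm running stats
    for p_src, p_dst in zip(src.parameters(), dst.parameters()):
        p_dst.copy_(p_src)  # parameters only, buffers are left untouched

mismatched = diff_state(src, dst)
print(mismatched)
```

If the list comes back empty on the real model and the outputs still differ, the cause is probably outside the state dict, for example a branch left in train mode (active dropout, batch norm using batch statistics) or a difference in the forward pass.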

Would it be easier to make 3 instances of the same model, load the state dictionary (which holds the trained branch 0) into each, and somehow merge the branch 0s together and rename them to 1/2? Although I'm not sure how I could do that either.

Could you post your model architecture so that we could have a look?
The mentioned approach sounds simpler and if you would just like to load the same parameters and buffers, this dummy code shows how to do so using submodules:

import torch.nn as nn
from torchvision import models

class MyModel(nn.Module):
    def __init__(self, moduleA, moduleB, moduleC):
        super().__init__()
        self.moduleA = moduleA
        self.moduleB = moduleB
        self.moduleC = moduleC

    def forward(self, x):
        #...
        return x

moduleA = models.resnet18()
moduleB = models.resnet18()
moduleC = models.resnet18()

# load_state_dict copies all parameters and buffers of moduleA
moduleB.load_state_dict(moduleA.state_dict())
moduleC.load_state_dict(moduleA.state_dict())

model = MyModel(moduleA, moduleB, moduleC)
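A possible variation on the same idea, sketched here with a toy `trained_branch` (hypothetical), is `copy.deepcopy`, which clones a module together with its parameters and buffers without having to construct the architecture a second time:

```python
import copy
import torch
import torch.nn as nn

# Hypothetical stand-in for an already trained branch.
trained_branch = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.ReLU())

# Branch 0 is the original, branches 1 and 2 are independent deep copies.
branches = nn.ModuleList(
    [trained_branch] + [copy.deepcopy(trained_branch) for _ in range(2)]
)

same_values = torch.equal(branches[1][0].weight, branches[0][0].weight)
independent = branches[1][0].weight.data_ptr() != branches[0][0].weight.data_ptr()
print(same_values, independent)  # True True
```

Whether this is more convenient than `load_state_dict` depends on how the real model builds its branches.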

However, your use case seems to be more complicated, so we would need some more information.

Hey, sorry for the late reply.
Indeed, I found a post of yours about model ensembles somewhere else and decided to go with creating 3 instances of the same model and loading the state dictionary into each of them in the usual way:

import torch

# PABN and EnsembleModels are defined elsewhere in my code
# define 3 submodels
model0 = PABN()
model1 = PABN()
model2 = PABN()

# load the same checkpoint into all 3 models
checkpoint = torch.load('epoch_500.pth')
for submodel in (model0, model1, model2):
    submodel.load_state_dict(checkpoint['model_state_dict'], strict=False)

# freeze backbones
freeze_backbones = True
if freeze_backbones:
    for submodel in (model0, model1, model2):
        for name, child in submodel.named_children():
            if name == 'backbone':
                # requires_grad is a plain flag, so torch.no_grad() is not needed here
                for param in child.parameters():
                    param.requires_grad = False

# create a new model ensemble
model = EnsembleModels(model0, model1, model2)

This ended up behaving just like an actual branching tree network. There are 3 different backbones, but they are frozen, so they do not change and stay exactly the same. The downside is the higher memory usage, I guess.
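One caveat about "stay exactly the same": setting `requires_grad = False` stops the optimizer from updating the weights, but batch norm layers in train mode still update their running statistics on every forward pass. A small sketch with a toy backbone (hypothetical, the PABN backbone is not shown in this thread):

```python
import torch
import torch.nn as nn

# Toy backbone (hypothetical stand-in for the frozen backbone).
backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# Freeze the trainable parameters.
for param in backbone.parameters():
    param.requires_grad_(False)

bn = backbone[1]

# In train mode the running stats still move, even though the weights are frozen.
backbone.train()
before = bn.running_mean.clone()
backbone(torch.randn(16, 3, 10, 10))
changed_in_train = not torch.equal(before, bn.running_mean)

# In eval mode the running stats are used but not updated.
backbone.eval()
before = bn.running_mean.clone()
backbone(torch.randn(16, 3, 10, 10))
changed_in_eval = not torch.equal(before, bn.running_mean)

print(changed_in_train, changed_in_eval)  # True False
```

Calling `model.train()` on the parent model switches all submodules back to train mode, so the frozen backbones would need `.eval()` again after each such call if their statistics should stay fixed.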

Note that if I did not freeze the backbones, I got a lower performance from the ensemble model.

But currently I would like to further improve my architecture and branch my network another time.

And here I have the problem again. This time I would again like to copy weights and biases in a reliable way, but I still don't know how to do it. When I used the previously mentioned method (copying weight data and bias data), I got a close but still not identical performance, so I cannot trust it. As I remember, training after that also resulted in lower performance compared to loading the state dictionary and ensembling with frozen conv0-3.
Since I know that training conv4+conv5 resulted in better performance, I'm not sure about freezing conv4 as well. I have the feeling that those higher-level features do need some retraining in order to specialize for their data, so I cannot really use the same method as when I froze conv0-3. I think freezing conv0-3 was okay because they hold low-level features that don't need much specialization.
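Freezing only the lower layers while leaving conv4 and conv5 trainable could be sketched as below. `TinyNet` and its layer names are hypothetical and only mirror the conv3/conv4/conv5 wording above; in the real model the name prefixes would have to match whatever `named_parameters()` prints.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # Hypothetical model whose layer names mirror the wording above.
    def __init__(self):
        super().__init__()
        self.conv3 = nn.Conv2d(3, 8, 3)
        self.conv4 = nn.Conv2d(8, 8, 3)
        self.conv5 = nn.Conv2d(8, 8, 3)

    def forward(self, x):
        return self.conv5(self.conv4(self.conv3(x)))

model = TinyNet()

# Freeze every parameter whose name starts with one of these prefixes.
frozen_prefixes = ("conv3",)
for name, param in model.named_parameters():
    param.requires_grad_(not name.startswith(frozen_prefixes))

# Hand only the trainable parameters to the optimizer.
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.01
)
trainable = [name for name, p in model.named_parameters() if p.requires_grad]

# One dummy training step: conv3 stays fixed, conv4 is updated.
conv3_before = model.conv3.weight.clone()
conv4_before = model.conv4.weight.clone()
loss = model(torch.randn(2, 3, 12, 12)).sum()
loss.backward()
optimizer.step()

print(trainable)
print(torch.equal(conv3_before, model.conv3.weight))  # True
```

With this split, conv4 and conv5 can keep specializing in each branch while the shared low-level layers stay at their pretrained values.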