Creating a model whose weights are the sum of the weights of 2 different neural networks

I am doing a transfer-learning experiment.
I trained 2 CNNs that have exactly the same structure, one on MNIST and one on SVHN.
I obtained the parameters (weights and biases) of the 2 models.
Now I want to combine these weights (sum them, or apply other operations). Something like this:

modelMNIST.parameters()
modelSVHN.parameters()

#now the new model
model3 = MyCNN(1)
model3.parameters = modelMNIST.parameters()+modelSVHN.parameters()

If I do it this way, I obtain this error:
SyntaxError: can't assign to function call

And if I do it this way:

model3.block_1[0].weight = modelMNIST.block_1[0].weight + modelSVHN.block_1[0].weight

I get this error:

TypeError: cannot assign 'torch.cuda.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)

Do I have to use load_state_dict in some way?

Hi Bruno!

I have both a conceptual comment and a couple of technical comments.

On the conceptual side:

I have concerns about combining weights like this from two
independently-trained models. (I have heard reports that this
works or can be made to work, but I’m skeptical.)

The issue is that there is a lot of arbitrariness and redundancy
in the weights of a neural-network model. If I take two identical
models, but give them different (but equivalent) initializations,
and train them on the same training data (but probably batched
up into different (but equivalent) random batches), there is no
reason for “weight-17” in model A to have the same value as
“weight-17” in model B.

Training model A takes it down some particular path along the
“loss-surface” to some reasonable location in weight-configuration
space. But because of the redundancy and randomness, training
model B takes it down some entirely different (but comparably
worthwhile) path to some entirely different location (that is comparably
good in making predictions).

Speaking somewhat figuratively, let’s say that, by happenstance,
the weights on the “left-hand” side of model A learn to recognize
straight edges, while weights on the “right-hand” side learn curved
edges. But now assume that, by happenstance, this is reversed in
model B, so the left side learns curved edges and the right, straight
edges.

So weight-17 in model A means straight, but weight-17 in model B
means curved. These two weights mean two different things, so it
doesn’t make sense to add, or average, or otherwise combine them
together.

On the technical side:

.parameters is a method (“member function”) of Module, while
.parameters() is the result of invoking that method (calling that
function). You are trying to assign something to the function itself,
as the error message indicates.

But furthermore, the result of calling .parameters() is also not
something you can assign to. It’s a “generator” – something you
can really just iterate over.

So you could simultaneously iterate over the result of calling
.parameters() on your three models.

But …

A Parameter is not a Tensor – it’s a wrapper for a Tensor that
has some added functionality. When you add two Parameters
together, you don’t get a new Parameter – you get a new Tensor
that is the sum of the two wrapped Tensors. If you assign a Tensor
to a (reference to a) Parameter, you don’t get a (reference to a)
Parameter – you get a (reference to a) Tensor (and a Module
is smart enough not to let you do this).
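
For example (a small self-contained illustration):

import torch

p1 = torch.nn.Parameter(torch.ones(3))
p2 = torch.nn.Parameter(torch.ones(3))
print(type(p1))       # <class 'torch.nn.parameter.Parameter'>
print(type(p1 + p2))  # <class 'torch.Tensor'> -- the sum is a plain Tensor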

You would need to do something like this:

import torch

l1 = torch.nn.Linear(2, 3)
l2 = torch.nn.Linear(2, 3)
l3 = torch.nn.Linear(2, 3)
with torch.no_grad():
    # overwrite l1's weight in place with the sum of l2's and l3's weights
    l1.weight.copy_(l2.weight + l3.weight)
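
To do the same thing for two whole models with identical architectures,
you could iterate over their parameters in parallel. Something along
these lines should work (a sketch using your model3, modelMNIST, and
modelSVHN, assuming all three share the same architecture):

with torch.no_grad():
    for p3, pm, ps in zip(model3.parameters(),
                          modelMNIST.parameters(),
                          modelSVHN.parameters()):
        # overwrite model3's parameter in place with the element-wise sum
        p3.copy_(pm + ps)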

Best.

K. Frank


Hi Frank!

I have concerns about combining weights like this from two
independently-trained models. (I have heard reports that this
works or can be made to work, but I’m skeptical.)

Can you send me these papers/reports, if you can find them? I am doing this experiment for my thesis, and I have just started with it.

The issue is that there is a lot of arbitrariness and redundancy
in the weights of a neural-network model. If I take two identical
models, but give them different (but equivalent) initializations,
and train them on the same training data (but probably batched
up into different (but equivalent) random batches), there is no
reason for “weight-17” in model A to have the same value as
“weight-17” in model B.

Yes, I know, but my supervisor assigned me this work, and we want to obtain these bad results first. My supervisor has several ideas on how to try to improve on them later.

For the technical solution: I have modelMNIST and modelSVHN.
You mean I have to do this?

with torch.no_grad():
    model3.block_1[0].weight.copy_(modelMNIST.block_1[0].weight + modelSVHN.block_1[0].weight)

And do this for all the blocks (I have 4 blocks + 1 classifier) and for each layer of the blocks?

Okay, I found the solution:

import torch
import torch.nn as nn


class VGG16SUM(nn.Module):
    
    def __init__(self, model1, model2, num_classes):
        super(VGG16SUM, self).__init__()

        # calculate same padding:
        # (w - k + 2*p)/s + 1 = o
        # => p = (s(o-1) - w + k)/2

        self.block_1 = nn.Sequential(
            nn.Conv2d(in_channels=1,
                      out_channels=64,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      # (1(32-1)- 32 + 3)/2 = 1
                      padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64,
                      out_channels=64,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )

        self.block_2 = nn.Sequential(
            nn.Conv2d(in_channels=64,
                      out_channels=128,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128,
                      out_channels=128,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )
        
        self.block_3 = nn.Sequential(
            nn.Conv2d(in_channels=128,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )

        self.block_4 = nn.Sequential(
            nn.Conv2d(in_channels=256,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        ) 


        self.classifier = nn.Sequential(
            nn.Linear(2048, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.25),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.25),
            nn.Linear(4096, num_classes),
        )

        # copy the element-wise sum of the two models' parameters into this model
        with torch.no_grad():
            for p_out, p_in1, p_in2 in zip(self.parameters(), model1.parameters(), model2.parameters()):
                p_out.copy_(p_in1 + p_in2)

    def forward(self, x):

        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        x = self.block_4(x)
        # x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
        #logits = self.classifier(x)
        #probas = F.softmax(logits, dim=1)
        # probas = nn.Softmax(logits)
        #return probas
        # return logits
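
With that class in place, usage would be something like this (a sketch
assuming, for example, 10 output classes):

model3 = VGG16SUM(modelMNIST, modelSVHN, num_classes=10)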

Hi Bruno!

I don’t have any specific references. I have a vague recollection that
I expressed similar skepticism in another thread on this forum and the
poster ended up replying that he got his scheme working.

Some thoughts:

First, a disclaimer: I haven’t tried any relevant experiments myself,
and I don’t know of any literature on this issue.

But let’s say you want to try this and make it work – that is, you want
to train two models with identical architectures independently on two
similar, but distinct problems, and then average the weights together
somehow.

Why might doing something like this have value? My intuition runs as
follows: The upstream weights of a network tend to learn “lower-level”
features that are likely to be common to both problems. For example,
upstream convolutions might learn to detect edges. In contrast, the
downstream weights learn “higher-level” features more tuned to the
specific problem being trained on. For example, even though the two
tasks are similar, a downstream fully-connected layer trained on
zipcodes might learn handwritten digits, while one trained on house
numbers might do better with block-letter-style printed or engraved digits.

By “averaging” together the upstream weights, you might do a better
(and perhaps more generalizable) job recognizing the more generic,
lower-level features. (It would seem harder to gain any advantage with
the more-problem-specific features.)

Let’s say this is all true and can be made to work somehow. The
concern I expressed above still remains: After independent training,
because of the redundancies in the network weights, there’s no reason
that weight-17 in model A plays the same role or has the same meaning
as weight-17 in model B. So it doesn’t make sense to combine them
together, using an average or otherwise.

However, what if, by (extremely unlikely) happenstance, model A and
model B, while being trained, did end up following similar paths, and
ended up at similar locations in weight-configuration space? Now it
could make sense to combine the two weight-17s together, as they
could now have similar meaning.

One approach could be to have the two models guide one another
along similar paths while training.

The idea would be that you could start the two models at the same
(random) location – i.e., use the same random initialization for weights
of both models – and then add a loss term that nudges the two models
to prefer similar paths.

Concretely, you could use ((weights_A - weights_B)**2).sum()
as an added loss term for training both models, where, when taking
the optimization step for model A, weights_B are viewed as fixed
(non-trainable) parameters, and vice versa.
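
As a rough sketch of this idea (using toy Linear models as stand-ins
for your CNNs; lambda_prox, task_loss_A, and task_loss_B are
hypothetical placeholders):

import torch
import torch.nn as nn

# two toy models with identical architectures
net_A = nn.Linear(10, 2)
net_B = nn.Linear(10, 2)
net_B.load_state_dict(net_A.state_dict())   # same initialization for both

def proximity_loss(model, reference):
    # sum of squared differences between corresponding weights;
    # the reference model's weights are detached, i.e. treated as fixed
    return sum(((p - q.detach()) ** 2).sum()
               for p, q in zip(model.parameters(), reference.parameters()))

# when taking an optimization step for model A:
#     loss_A = task_loss_A + lambda_prox * proximity_loss(net_A, net_B)
# and for model B:
#     loss_B = task_loss_B + lambda_prox * proximity_loss(net_B, net_A)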

If you accept the intuition that the upstream weights are more likely
than downstream weights to play similar roles in the two models, you
might choose to weight the upstream weights more heavily in the
proposed added loss term.

Taking this idea to its logical extreme, you could weight the upstream
weights so heavily that the upstream weights of the two models are
forced to be identical, and give zero weight to the downstream weights
so that they train completely independently.

Of course, doing this would just be what people already do when they
train two-headed networks: there is a common upstream part of the
network, but two downstream heads that make two sets of predictions
for two different problems. These feed two different loss functions
that are added together into a single loss that is then backpropagated.
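
For reference, a minimal sketch of such a two-headed network (the layer
sizes here are just placeholders):

import torch.nn as nn

class TwoHeadedNet(nn.Module):
    def __init__(self):
        super().__init__()
        # common upstream part shared by both problems
        self.trunk = nn.Sequential(nn.Linear(784, 256), nn.ReLU())
        # one downstream head per problem
        self.head_mnist = nn.Linear(256, 10)
        self.head_svhn = nn.Linear(256, 10)

    def forward(self, x):
        features = self.trunk(x)
        return self.head_mnist(features), self.head_svhn(features)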

From this perspective using such a (potentially-weighted) loss term to
guide the two models along similar paths could be viewed as a way
to interpolate between training the two models fully independently,
combining the weights together after the fact, and training a single
model with two heads.

Best.

K. Frank


Hi Bruno!

As an answer to your original technical question, this looks fine to me.

Best.

K. Frank.


The above solution does not work for me. I mean, I am doing this:

with torch.no_grad():
    for p_out, p_in1, p_in2 in zip(teacher.parameters(), student.parameters(), teacher.parameters()):
        p_out.data = nn.Parameter(p_in1 + p_in2)

But when I use the teacher for my inference, it gives empty outputs, like this:

tensor([], device='cuda:0', size=(0, 4))
tensor([], device='cuda:0', size=(0, 4))
tensor([], device='cuda:0', size=(0, 4))
tensor([], device='cuda:0', size=(0, 4))

One thing I want to add: the teacher and the student are the same model. The model is torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True).