Averaging models with differentiable weighting coefficients

Suppose I have a couple of models given as state_dicts. Instead of optimizing them further, I’d like to average these models (their weights!) for every batch with weighting coefficients, e.g. 0.1 * model_1 + 0.4 * model_2 + …

However, I do not intend to optimize the model_i further; rather, I would like to perform the forward and backward pass with the averaged model and then update only the coefficients.

How can I average the models such that the grad_fn w.r.t. the coefficient tensor persists? I keep failing at this: I always lose the gradients as soon as I assign the averaged parameters (which do have valid grad_fns) to the base model used for inference.
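
A minimal sketch of the failing pattern (the layer sizes and coefficient values are made up for illustration):

import torch
import torch.nn as nn

model_1 = nn.Linear(10, 10)
model_2 = nn.Linear(10, 10)
base_model = nn.Linear(10, 10)

# the coefficients I actually want to optimize
coeffs = torch.tensor([0.1, 0.4], requires_grad=True)

# the averaged weight has a valid grad_fn w.r.t. coeffs ...
avg_weight = coeffs[0] * model_1.weight + coeffs[1] * model_2.weight
print(avg_weight.grad_fn)  # <AddBackward0 object at ...>

# ... but the weight of an nn.Module has to stay a leaf Parameter,
# so assigning the averaged tensor breaks the connection to coeffs
with torch.no_grad():
    base_model.weight.copy_(avg_weight)

base_model(torch.randn(2, 10)).sum().backward()
print(coeffs.grad)  # None - the coefficients never receive gradients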

Thanks!


Hi Max!

PyTorch layers such as Linear and Conv2d are pretty insistent about
having their weights be Parameters and leaf tensors. There may be
some way around this that lets you assign a non-leaf tensor, such as
averaged parameters, to be the weight of such a layer, but I don’t
know how to do it.

What you can do is let your “weighting coefficients” be trainable
parameters (with requires_grad = True) and compute your averaged
weight tensors as a function of the “weighting coefficients” and the
weight tensors of your model_i. Then pass the averaged weight
tensors into the functional forms of layers such as linear() and conv2d().

Gradients will now pass back properly through the functional-form
layers to your trainable “weighting coefficients” parameters, which
you can then optimize.

This has the disadvantage that you have to rewrite your model to use
functional-form layers – you can’t just instantiate a new instance of
your model and, for example, load its state dict with averaged weights.
But I don’t know of a scheme that avoids such a rewrite.
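
A minimal sketch of such a rewrite, assuming two source state_dicts and a
single Linear layer (the sizes and coefficient values are made up for
illustration):

import torch
import torch.nn.functional as F

# two fixed source models, given as state_dicts (illustrative sizes)
sd1 = {"weight": torch.randn(10, 10), "bias": torch.randn(10)}
sd2 = {"weight": torch.randn(10, 10), "bias": torch.randn(10)}

# the only trainable parameters: the weighting coefficients
coeffs = torch.nn.Parameter(torch.tensor([0.1, 0.4]))
optimizer = torch.optim.SGD([coeffs], lr=0.1)

def averaged_linear(x):
    # the averaged tensors keep their grad_fn w.r.t. coeffs
    w = coeffs[0] * sd1["weight"] + coeffs[1] * sd2["weight"]
    b = coeffs[0] * sd1["bias"] + coeffs[1] * sd2["bias"]
    return F.linear(x, w, b)

loss = averaged_linear(torch.randn(4, 10)).mean()
loss.backward()
print(coeffs.grad)  # valid gradients w.r.t. the coefficients
optimizer.step()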

Best.

K. Frank


That’s a good idea, and I believe torch.nn.utils.parametrize could provide a good approach for this use case:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Average(nn.Module):
    def __init__(self, w1, w2):
        super().__init__()
        self.w1 = w1
        self.w2 = w2
    
    def forward(self, X):
        # X is the original (unparametrized) weight of the layer and is ignored;
        # the parametrized weight is simply the mean of the two source weights
        return (self.w1 + self.w2) / 2.


layer1 = nn.Linear(10, 10)
layer2 = nn.Linear(10, 10)

optimizer = torch.optim.Adam(list(layer1.parameters()) + list(layer2.parameters()), lr=1.)

layer3 = nn.Linear(10, 10)
parametrize.register_parametrization(layer3, "weight", Average(layer1.weight, layer2.weight))

# check if parametrization works
print(((layer1.weight + layer2.weight) / 2. - layer3.weight).abs().max())
# tensor(0., grad_fn=<MaxBackward1>)

# original weights are still different
print((layer1.weight - layer2.weight).abs().max())
# tensor(0.5682, grad_fn=<MaxBackward1>)

x = torch.randn(1, 10)
out = layer3(x)
out.mean().backward()

# check for valid gradients
print(layer1.weight.grad.abs().sum())
# tensor(5.5434)
print(layer2.weight.grad.abs().sum())
# tensor(5.5434)

# update parameters
optimizer.step()

# check again if parametrization works
print(((layer1.weight + layer2.weight) / 2. - layer3.weight).abs().max())
# tensor(0., grad_fn=<MaxBackward1>)

# original weights are still different
print((layer1.weight - layer2.weight).abs().max())
# tensor(0.5682, grad_fn=<MaxBackward1>)
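
For the original question, the same mechanism also works with trainable coefficients instead of a fixed 0.5/0.5 average. A sketch building on the example above (it reuses layer1, layer2, and the imports; WeightedAverage, layer4, and the initial coefficients are made-up names and values):

# variant with trainable weighting coefficients; the source weights are
# stored as buffers so that only the coefficients get optimized
class WeightedAverage(nn.Module):
    def __init__(self, w1, w2):
        super().__init__()
        self.register_buffer("w1", w1.detach().clone())
        self.register_buffer("w2", w2.detach().clone())
        self.coeffs = nn.Parameter(torch.tensor([0.1, 0.4]))

    def forward(self, X):
        # X (the original weight of the parametrized layer) is ignored
        return self.coeffs[0] * self.w1 + self.coeffs[1] * self.w2


layer4 = nn.Linear(10, 10)
weighted_avg = WeightedAverage(layer1.weight, layer2.weight)
parametrize.register_parametrization(layer4, "weight", weighted_avg)

coeff_optimizer = torch.optim.Adam([weighted_avg.coeffs], lr=0.1)

layer4(torch.randn(1, 10)).mean().backward()
print(weighted_avg.coeffs.grad)  # gradients flow to the coefficients
coeff_optimizer.step()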

Thanks, I think I got it to work in a similar fashion!

Unfortunately, I ran into a problem with this approach. On a single GPU everything works fine, but as soon as I switch to DataParallel mode, I get errors during inference saying that the input and weight tensors do not lie on the same device. After trying several things, I figured this makes sense, since my model is now DataParallel while my parametrization (Average in your example) lives on a single GPU.

To also wrap the parametrization in DataParallel, I tried the following (adapted to your example):

reparametrization_instance = Average(layer1.weight, layer2.weight)
reparametrization_instance = torch.nn.DataParallel(reparametrization_instance)
parametrize.register_parametrization(module, 'weight', reparametrization_instance)

However, I get the following error:

Registering a parametrization may not change the shape of the tensor, unless unsafe flag is enabled.
unparametrized shape: torch.Size([64, 3, 7, 7])
parametrized shape: torch.Size([128, 3, 7, 7])

Do you have any advice? Thanks a lot!

I don’t know if parametrizations are supported with nn.DataParallel, as it’s deprecated and we generally recommend using DistributedDataParallel instead, which I assume should work fine. The shape error itself is expected, by the way: nn.DataParallel scatters its input (here the unparametrized weight) along dim 0 and concatenates the outputs of all replicas, so the parametrization ends up returning a tensor with a larger first dimension.