Averaging models with differentiable weighting coefficients

That’s a good idea and I believe torch.nn.utils.parametrize could provide a good approach for this use case:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Average(nn.Module):
    def __init__(self, w1, w2):
        super().__init__()
        # w1 and w2 are the nn.Parameters of the two base layers; assigning them
        # here keeps them as leaves, so gradients still flow back to layer1 and layer2
        self.w1 = w1
        self.w2 = w2

    def forward(self, X):
        # X (the original weight of the parametrized layer) is ignored;
        # the exposed weight is simply the average of the two base weights
        return (self.w1 + self.w2) / 2.


layer1 = nn.Linear(10, 10)
layer2 = nn.Linear(10, 10)

optimizer = torch.optim.Adam(list(layer1.parameters()) + list(layer2.parameters()), lr=1.)

layer3 = nn.Linear(10, 10)
parametrize.register_parametrization(layer3, "weight", Average(layer1.weight, layer2.weight))

# check if parametrization works
print(((layer1.weight + layer2.weight) / 2. - layer3.weight).abs().max())
# tensor(0., grad_fn=<MaxBackward1>)

# original weights are still different
print((layer1.weight - layer2.weight).abs().max())
# tensor(0.5682, grad_fn=<MaxBackward1>)

x = torch.randn(1, 10)
out = layer3(x)
out.mean().backward()

# check for valid gradients: each base weight receives half of the gradient
# w.r.t. layer3.weight, so both gradients are identical
print(layer1.weight.grad.abs().sum())
# tensor(5.5434)
print(layer2.weight.grad.abs().sum())
# tensor(5.5434)

# update parameters
optimizer.step()

# check again if parametrization works
print(((layer1.weight + layer2.weight) / 2. - layer3.weight).abs().max())
# tensor(0., grad_fn=<MaxBackward1>)

# original weights are still different
print((layer1.weight - layer2.weight).abs().max())
# tensor(0.5682, grad_fn=<MaxBackward1>)
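
Since the question asks for differentiable weighting coefficients, the same parametrization can also carry a learnable mixing coefficient instead of the fixed 0.5/0.5 average. Below is a minimal sketch of that idea; the names WeightedAverage, alpha, layer4, and optimizer2 are illustrative and not part of the parametrize API:

class WeightedAverage(nn.Module):
    def __init__(self, w1, w2):
        super().__init__()
        self.w1 = w1
        self.w2 = w2
        # raw mixing logit; the sigmoid in forward keeps the coefficient in (0, 1)
        self.alpha = nn.Parameter(torch.zeros(()))

    def forward(self, X):
        # X is ignored again; the exposed weight is a convex combination of w1 and w2
        a = torch.sigmoid(self.alpha)
        return a * self.w1 + (1. - a) * self.w2


layer4 = nn.Linear(10, 10)
weighted = WeightedAverage(layer1.weight, layer2.weight)
parametrize.register_parametrization(layer4, "weight", weighted)

# optimize the two base layers and the mixing coefficient together
optimizer2 = torch.optim.Adam(
    list(layer1.parameters()) + list(layer2.parameters()) + [weighted.alpha], lr=1.)

optimizer2.zero_grad()
out = layer4(torch.randn(1, 10))
out.mean().backward()

# alpha now holds a gradient as well and will be updated by the step
print(weighted.alpha.grad)
optimizer2.step()

The sigmoid is just one way to keep the coefficient in a sensible range; a softmax over several coefficients would generalize this to averaging more than two models.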