Averaging models with differentiable weighting coefficients

That’s a good idea and I believe `torch.nn.utils.parametrize` could provide a good approach for this use case:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Average(nn.Module):
    """Parametrization module that replaces a parameter with the mean of
    two externally supplied tensors.

    Intended for use with ``parametrize.register_parametrization``: the
    parametrized layer's weight is then always computed as
    ``(w1 + w2) / 2`` while gradients flow back to ``w1`` and ``w2``.
    """

    def __init__(self, w1, w2):
        super().__init__()
        # Stored as plain attributes (not sub-parameters) so the tensors
        # remain owned — and optimized — by their original modules.
        self.w1 = w1
        self.w2 = w2

    def forward(self, X):
        # X is the original (replaced) parameter handed in by the
        # parametrization machinery; it is intentionally ignored — the
        # output depends only on the two source tensors.
        return (self.w1 + self.w2) / 2.

# Two independent source layers whose weights we want to average.
base_a = nn.Linear(10, 10)
base_b = nn.Linear(10, 10)

# Only the source layers are optimized; the averaged layer's weight is
# derived from them through the parametrization.
optimizer = torch.optim.Adam(
    list(base_a.parameters()) + list(base_b.parameters()), lr=1.
)

averaged = nn.Linear(10, 10)
parametrize.register_parametrization(
    averaged, "weight", Average(base_a.weight, base_b.weight)
)

# Sanity check: the parametrized weight equals the mean of the sources.
print(((base_a.weight + base_b.weight) / 2. - averaged.weight).abs().max())

# The source weights themselves are still distinct from each other.
print((base_a.weight - base_b.weight).abs().max())

x = torch.randn(1, 10)
out = averaged(x)
out.mean().backward()

# Gradients have flowed through the parametrization into the source
# layers' weights; apply the update.
optimizer.step()

# Verify the parametrization still holds after the optimizer step.
print(((base_a.weight + base_b.weight) / 2. - averaged.weight).abs().max())