Averaging some model parameters

I have k classes, for each of which I have trained a model. I would like to create a new model whose architecture is identical and whose parameters are the weighted average of the corresponding parameters of each of my k models according to some specified distribution p (e.g. [0.1, 0.2, 0.7] for k=3).

My goal is to construct such a model, evaluate it on data, and then have gradient propagate back to the original models (the intermediate averaged model does not matter to me). Would someone be able to help me understand the best way to do this?

To simply average the parameters of multiple models, you could average their state_dicts directly. However, based on your description you want the gradients to flow back to the parameters of the original models, so using an averaged state_dict wouldn't work.
Instead, I think the right approach would be to create the new averaged tensors from all models and use the functional API with these new tensors.
Something like this should work:

import torch
import torch.nn as nn
import torch.nn.functional as F

model1 = nn.Linear(10, 10)
model2 = nn.Linear(10, 10)

# average the parameters as plain tensors so the autograd history is kept
new_weight = (model1.weight + model2.weight) / 2.
new_bias = (model1.bias + model2.bias) / 2.

x = torch.randn(1, 10)
# use the functional API with the averaged tensors
out = F.linear(x, new_weight, new_bias)

out.mean().backward()

# gradients now flow back to both original models
for name, param in model1.named_parameters():
    print(name, param.grad)

for name, param in model2.named_parameters():
    print(name, param.grad)

It does seem quite cumbersome, though, so maybe someone else has a better and cleaner approach.
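
For the weighted average over k models with a distribution p from your question, the same idea could look roughly like this (just a sketch reusing the imports above; the model list and the weights are placeholders):

models = [nn.Linear(10, 10) for _ in range(3)]
p = [0.1, 0.2, 0.7]

# weighted average of the parameters as plain tensors, so the graph is kept
new_weight = sum(w * m.weight for w, m in zip(p, models))
new_bias = sum(w * m.bias for w, m in zip(p, models))

x = torch.randn(1, 10)
out = F.linear(x, new_weight, new_bias)
out.mean().backward()

# each original model now has gradients scaled by its weight in p
print(models[0].weight.grad)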

Thanks so much! This certainly works for very simple models, but I agree it is extremely cumbersome for larger ones. Do you know why nn.Parameter doesn't retain the gradient history of the tensors used to create it?

nn.Parameter recreates a new parameter without any gradient history. While this new parameter would then receive gradients itself, the two tensors used to create it wouldn't.
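
A quick way to see this (a small sketch reusing model1 and model2 from above):

# the averaging op itself is recorded in the graph
avg = (model1.weight + model2.weight) / 2.
print(avg.grad_fn)   # e.g. <DivBackward0>

# wrapping it in nn.Parameter creates a fresh leaf tensor with no history
p = nn.Parameter(avg)
print(p.grad_fn)     # None: no connection back to model1/model2
print(p.is_leaf)     # True, so backward() would stop here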

Yeah, a bit of searching around showed me that this is still an open issue. Hopefully this feature is added soon. For now I'll consider the thread closed. Thanks again @ptrblck!

For “real” models this approach is infeasible (e.g. there is no “attention” functional, so you would need to reimplement the whole forward pass from scratch). Has anyone figured out another way to do this?
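
One possibility, if a recent PyTorch version is available, might be the stateless functional API. The following is only a minimal sketch assuming torch.func.functional_call exists in your version (the Sequential model here is just a placeholder for a “real” architecture); older releases exposed a similar torch.nn.utils.stateless.functional_call, if I remember correctly.

import torch
import torch.nn as nn
from torch.func import functional_call

# placeholder architecture: any nn.Module with the same structure would do
models = [nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
          for _ in range(3)]
p = [0.1, 0.2, 0.7]

# build the weighted-average parameters as plain tensors (autograd history is kept)
avg_params = {
    name: sum(w * dict(m.named_parameters())[name] for w, m in zip(p, models))
    for name, _ in models[0].named_parameters()
}

x = torch.randn(4, 10)
# run the forward pass of models[0], but with the averaged parameters substituted in
out = functional_call(models[0], avg_params, (x,))
out.mean().backward()

# gradients flow back to every original model
print(models[2][0].weight.grad)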