Averaging models with differentiable weighting coefficients

Hi Max!

PyTorch layers such as Linear and Conv2d are pretty insistent about
having their weights be Parameters and leaf tensors. There may be
some way around this that lets you assign a non-leaf tensor, such as
averaged parameters, to be the weight of such a layer, but I don’t
know how to do it.

What you can do is let your “weighting coefficients” be trainable
parameters (with requires_grad = True), compute your averaged
weight tensors as a function of the “weighting coefficients” and the
weight tensors of your model_i. Then pass the averaged weight
tensors into the functional forms of layers, such as
torch.nn.functional.linear() and conv2d().

Gradients will now pass back properly through the functional-form
layers to your trainable “weighting coefficients” parameters, which
you can then optimize.
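
Here is a minimal sketch of that scheme for a single Linear layer. The two
stand-in models, the names coeffs and averaged_forward, the softmax
normalization, and the random data are all assumptions made for
illustration; your own models and loss would take their place.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Two hypothetical models with identical architecture (stand-ins for model_i).
models = [nn.Linear(4, 3) for _ in range(2)]
for m in models:
    for p in m.parameters():
        p.requires_grad_(False)  # keep the source weights frozen

# Trainable "weighting coefficients", one per model.
coeffs = nn.Parameter(torch.zeros(len(models)))

def averaged_forward(x):
    # Softmax keeps the coefficients positive and summing to one
    # (an assumption; raw coefficients would work too).
    w = torch.softmax(coeffs, dim=0)
    # The averaged tensors are non-leaf tensors that depend on coeffs.
    weight = sum(w[i] * m.weight for i, m in enumerate(models))
    bias = sum(w[i] * m.bias for i, m in enumerate(models))
    # The functional form accepts non-leaf tensors as weight and bias.
    return F.linear(x, weight, bias)

# Only the weighting coefficients are handed to the optimizer.
optimizer = torch.optim.SGD([coeffs], lr=0.1)

x = torch.randn(8, 4)        # placeholder input batch
target = torch.randn(8, 3)   # placeholder targets

out = averaged_forward(x)
loss = F.mse_loss(out, target)
optimizer.zero_grad()
loss.backward()              # gradients flow back to coeffs
optimizer.step()
```

After backward(), coeffs.grad is populated while the frozen model weights
have no gradient. A convolutional layer would follow the same pattern with
F.conv2d() in place of F.linear().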

This has the disadvantage that you have to rewrite your model to use
functional-form layers – you can’t just instantiate a new instance of
your model and, for example, load its state dict with averaged weights.
But I don’t know of a scheme that avoids such a rewrite.

Best.

K. Frank
