Best way to compute the average of multiple models


I have a list of models (nn.Module instances) with the same architecture, and I need to compute their average model, i.e., a model with the same architecture whose weights are the average of the input models' weights.
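In symbols, for N input models, let θ_i^(k) be the tensor stored under state_dict key k in model i. The average model then holds, for every key, the element-wise mean:

```latex
\bar{\theta}^{(k)} = \frac{1}{N} \sum_{i=1}^{N} \theta_i^{(k)}
```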

An easy solution is to iterate over the keys in the state_dict of the target model and fill each entry with the weights the input models store under the same key (see code below). I am not sure this is the best way to do it, because iterating over the keys (or the modules) of the model is probably not optimal, and I was wondering if there is a faster way.

We can do the same by iterating over the modules instead of the state_dict, but it doesn't seem to help.

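    import torch

    # target_model has the same architecture as every model in `models`
    target_state_dict = target_model.state_dict()
    state_dicts = [m.state_dict() for m in models]  # fetch each state_dict only once
    for key in target_state_dict:
        if target_state_dict[key].dtype == torch.float32:
            # element-wise mean of this tensor across all input models
            target_state_dict[key] = sum(sd[key] for sd in state_dicts) / len(state_dicts)
    target_model.load_state_dict(target_state_dict)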
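For comparison, the parameter-based variant mentioned above can be written without state_dict keys by zipping the parameters of all models together. This is a minimal sketch, assuming every model has exactly the same architecture (so parameters() yields tensors in the same order and shape); average_into is a hypothetical helper name. It averages only parameters, not buffers such as BatchNorm running statistics.

```python
import torch
import torch.nn as nn


def average_into(target: nn.Module, models) -> nn.Module:
    """Overwrite target's parameters with the element-wise mean of the models' parameters.

    Hypothetical helper: assumes all models share target's architecture.
    Buffers (e.g. BatchNorm running stats) are left untouched.
    """
    with torch.no_grad():  # plain tensor arithmetic, no autograd graph
        # zip pairs the i-th parameter of target with the i-th parameter of every model
        for p_target, *p_models in zip(target.parameters(), *(m.parameters() for m in models)):
            p_target.copy_(torch.stack(p_models).mean(dim=0))
    return target


# Usage: average three small models into a fresh instance of the same architecture
torch.manual_seed(0)
models = [nn.Linear(4, 3) for _ in range(3)]
avg = average_into(nn.Linear(4, 3), models)
```

The outer loop still runs once per parameter tensor, but all models are combined in a single stack/mean call per tensor.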

I think averaging the values in the state_dicts is a valid approach, and I've also suggested it here in the past.
You should of course check whether this approach is valid at all for your trained models.
E.g. if you are trying to calculate the “average model” from completely different training runs, each model could converge to a different parameter set, and I would assume that the average model would perform poorly.
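One concrete reason for this caveat is permutation symmetry: swapping two hidden units (together with their outgoing weights) yields a network that computes exactly the same function with different weights, yet averaging the two parameter sets produces a different function. The toy sketch below uses hand-picked weights so the outputs can be checked by hand; make_net is a hypothetical helper for this illustration only.

```python
import torch
import torch.nn as nn


def make_net(w1, v):
    """Hypothetical helper: tiny 2-2-1 ReLU network with hand-set weights and zero biases."""
    net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
    with torch.no_grad():
        net[0].weight.copy_(torch.tensor(w1))
        net[0].bias.zero_()
        net[2].weight.copy_(torch.tensor([v]))
        net[2].bias.zero_()
    return net


# net_b is net_a with its two hidden units swapped: same function, different weights.
net_a = make_net([[1.0, 0.0], [0.0, 1.0]], [1.0, 2.0])
net_b = make_net([[0.0, 1.0], [1.0, 0.0]], [2.0, 1.0])

# Average the two state_dicts key by key and load the result into a third network.
sd_a, sd_b = net_a.state_dict(), net_b.state_dict()
net_avg = make_net([[0.0, 0.0], [0.0, 0.0]], [0.0, 0.0])
net_avg.load_state_dict({k: (sd_a[k] + sd_b[k]) / 2 for k in sd_a})

x = torch.tensor([[1.0, 0.0], [1.0, -1.0]])
with torch.no_grad():
    out_a, out_b, out_avg = net_a(x), net_b(x), net_avg(x)
print(out_a.squeeze(1).tolist())    # [1.0, 1.0]
print(out_b.squeeze(1).tolist())    # [1.0, 1.0]
print(out_avg.squeeze(1).tolist())  # [1.5, 0.0]
```

Independently trained models are not guaranteed to line up their hidden units, so their plain average can behave like this.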

Yes, I remember this solution from here. I think it's good in the case of two models, but I don't think it's the most efficient for a sequence of many models: as you can see, both the code you suggested and the one I pasted with this question use a double loop, which is very slow because loops in Python are slow.
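Note that a loop over state_dict keys runs once per parameter tensor (each layer's weight or bias), not once per individual weight, and each iteration is a vectorized tensor operation. If you still want to remove the per-key loop from your own code, one option is to flatten each model's parameters into a single vector with torch.nn.utils.parameters_to_vector, average all vectors with one stack/mean call, and write the result back with torch.nn.utils.vector_to_parameters. This is a sketch assuming identical architectures; average_flat is a hypothetical name, it only covers parameters (not buffers), and the stacked matrix temporarily holds one full copy of the parameters per model.

```python
import copy

import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters


def average_flat(models):
    """Hypothetical helper: new model whose parameters are the mean of the models' parameters."""
    with torch.no_grad():
        # One row per model: shape (num_models, total_number_of_parameters)
        flat = torch.stack([parameters_to_vector(m.parameters()) for m in models])
        avg_model = copy.deepcopy(models[0])  # same architecture (and buffers) as models[0]
        # Slice the averaged vector back into the copy's parameters, in order
        vector_to_parameters(flat.mean(dim=0), avg_model.parameters())
    return avg_model


torch.manual_seed(0)
models = [nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)) for _ in range(5)]
avg = average_flat(models)
```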

Having an option to average models would be very useful, for example for Stochastic Weight Averaging or other variations of this method, and I would be surprised if PyTorch doesn't support this feature.

PyTorch does support Stochastic Weight Averaging through the utilities in torch.optim.swa_utils (e.g. AveragedModel, SWALR and update_bn), so you might want to take a look at them.
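A minimal sketch of reusing torch.optim.swa_utils.AveragedModel to average a list of models, assuming they all share one architecture. With its default averaging function, each update_parameters call folds one more model into an equal-weight running mean of the parameters, so passing every model in the list yields their average. The model given to the constructor is only deep-copied as a template and is not counted in the average. By default, buffers (e.g. BatchNorm running statistics) are not averaged; in SWA they are typically recomputed afterwards with update_bn.

```python
import torch
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel

torch.manual_seed(0)
models = [nn.Linear(4, 3) for _ in range(4)]  # stand-ins for the trained models

# The constructor deep-copies models[0] as a template; it is not yet part of the average.
swa_model = AveragedModel(models[0])
for m in models:
    # Default averaging: running equal-weight mean over all models passed so far
    swa_model.update_parameters(m)

averaged = swa_model.module  # the wrapped nn.Linear holding the averaged weights
```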
