I have a list of models (`nn.Module` instances) with the same architecture, and I need to compute the average model: a model with the same architecture whose weights are the element-wise average of the input models' weights.
An easy solution is to iterate over the keys of the target model's `state_dict` and fill each entry with the average of the corresponding weights from the input models (see the code below). I'm not sure this is the best way to do it, since iterating over the keys (or the modules) of the model is probably not optimal, and I was wondering if there is a faster way.
We can do the same thing by iterating over the `modules` instead of the `state_dict`, but that doesn't seem to help.
```python
for key in target_state_dict:
    if target_state_dict[key].data.dtype == torch.float32:
        target_state_dict[key].data.fill_(0.)
        for model_id, model in enumerate(models):
            state_dict = model.state_dict()
            target_state_dict[key].data += state_dict[key].data.clone()
        target_state_dict[key].data /= len(models)
```
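For reference, here is a self-contained sketch of the same idea that averages each key with one `torch.stack(...).mean(dim=0)` call instead of the in-place zero/accumulate loop. The `nn.Linear` models are just placeholders for any set of same-architecture modules, and the non-float branch (e.g. BatchNorm's `num_batches_tracked` buffer) simply copies the first model's value:

```python
import torch
import torch.nn as nn

# Hypothetical setup: three small models sharing one architecture.
models = [nn.Linear(4, 2) for _ in range(3)]
target = nn.Linear(4, 2)

state_dicts = [m.state_dict() for m in models]

avg_state_dict = {}
for key, tensor in state_dicts[0].items():
    if tensor.dtype.is_floating_point:
        # Stack the per-model tensors along a new dim 0 and average them.
        avg_state_dict[key] = torch.stack([sd[key] for sd in state_dicts]).mean(dim=0)
    else:
        # Integer buffers can't be averaged meaningfully; keep the first model's copy.
        avg_state_dict[key] = tensor.clone()

target.load_state_dict(avg_state_dict)
```

This still iterates over the keys once in Python, but the per-key averaging itself is a single vectorized operation, and building a fresh `state_dict` avoids mutating the target's tensors in place.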