Linear interpolation between checkpoints in PyTorch - Batch Normalization

I want to linearly interpolate between two trained PyTorch model checkpoints. For all layers except batch normalization, I load the state dicts and simply do the linear interpolation as follows:

def interpolate_state_dicts(state_dict_1, state_dict_2, weight):
    return {key: (1 - weight) * state_dict_1[key] + weight * state_dict_2[key]
            for key in state_dict_1.keys()}

I do not know whether we can simply do the same for the BN layer parameters and buffers (weight, bias, running_mean, running_var). I guess it is not that simple, since the running stats are estimated from the training batches.
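For reference, here is a minimal, self-contained sketch of how I use the function (the two `nn.Linear` models are just placeholders for my actual architecture):

```python
import torch
import torch.nn as nn

def interpolate_state_dicts(state_dict_1, state_dict_2, weight):
    return {key: (1 - weight) * state_dict_1[key] + weight * state_dict_2[key]
            for key in state_dict_1.keys()}

# two checkpoints of the same (placeholder) architecture
model_a = nn.Linear(4, 2)
model_b = nn.Linear(4, 2)

merged = nn.Linear(4, 2)
merged.load_state_dict(
    interpolate_state_dicts(model_a.state_dict(), model_b.state_dict(), 0.5))

# at weight=0.5 every parameter is the elementwise midpoint
expected = 0.5 * (model_a.weight + model_b.weight)
print(torch.allclose(merged.weight, expected))  # True
```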

The running stats were updated using all training batches, so if you assume that interpolating the parameters works, interpolating the running stats might work as well.
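As a quick sanity check (a sketch; the `BatchNorm1d` module is just an example), the running stats are stored as buffers in the `state_dict`, so the same dict comprehension covers them. One caveat: `num_batches_tracked` is an integer counter, so you probably don't want to interpolate it as a float:

```python
import torch
import torch.nn as nn

bn_a = nn.BatchNorm1d(3)
bn_b = nn.BatchNorm1d(3)

# update the running stats with different "training" batches
bn_a.train()
bn_b.train()
bn_a(torch.randn(16, 3) + 1.0)
bn_b(torch.randn(16, 3) - 1.0)

sd_a, sd_b = bn_a.state_dict(), bn_b.state_dict()
print(sorted(sd_a.keys()))
# ['bias', 'num_batches_tracked', 'running_mean', 'running_var', 'weight']

w = 0.5
merged = {}
for k in sd_a:
    if sd_a[k].is_floating_point():
        merged[k] = (1 - w) * sd_a[k] + w * sd_b[k]
    else:
        # integer counters (num_batches_tracked) are copied, not interpolated
        merged[k] = sd_a[k]

bn_c = nn.BatchNorm1d(3)
bn_c.load_state_dict(merged)
```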

EDIT: your use case might also be similar to Stochastic Weight Averaging (SWA), so you could take a look at how the parameters are averaged there and how batchnorm layers are treated.
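For reference, a minimal sketch of the SWA utilities in `torch.optim.swa_utils` (the model and training loop here are toy placeholders): `AveragedModel` keeps a running average of the parameters, while `update_bn` does *not* average the batchnorm stats — it recomputes them with a forward pass over the training data:

```python
import torch
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel, update_bn

# toy model with a batchnorm layer
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8),
                      nn.ReLU(), nn.Linear(8, 2))
swa_model = AveragedModel(model)

# toy "training data": (input, target) batches
loader = [(torch.randn(16, 4), torch.randint(0, 2, (16,))) for _ in range(5)]
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

for epoch in range(3):
    for x, y in loader:
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        optimizer.step()
    swa_model.update_parameters(model)  # running average of the weights

# batchnorm running stats are recomputed from scratch, not averaged
update_bn(loader, swa_model)
```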