I want to linearly interpolate between two trained PyTorch model checkpoints. For all layers except the batch normalization layers, I load the state dicts and simply do the linear interpolation as follows:
```python
def interpolate_state_dicts(state_dict_1, state_dict_2, weight):
    # weight = 0 returns state_dict_1, weight = 1 returns state_dict_2
    return {key: (1 - weight) * state_dict_1[key] + weight * state_dict_2[key]
            for key in state_dict_1.keys()}
```
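I do not know whether we can simply do the same for the BN layer entries (weight, bias, running_mean, running_var). I guess it is not that simple, since the running mean and variance are not learned parameters: they are moving averages of the activation statistics that each particular network produced on its training batches.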
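A minimal sketch of one option, assuming the interpolated weights have already been loaded into a model: keep interpolating BN's learnable `weight` and `bias` like any other parameter, then discard the interpolated running statistics and re-estimate them with forward passes over some training data. PyTorch ships `torch.optim.swa_utils.update_bn` for this purpose in stochastic weight averaging, where averaged weights have the same problem. The toy `model` and random `inputs` below are hypothetical placeholders for the interpolated network and a real data loader.

```python
import torch
import torch.nn as nn
from torch.optim.swa_utils import update_bn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for the interpolated model; substitute your own network
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))
bn = model[1]

# BN state: learnable weight/bias plus buffers (running variance, not std)
print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']

# Hypothetical data; in practice use (a subset of) the training set
torch.manual_seed(0)
inputs = torch.randn(256, 4) * 3 + 1
loader = DataLoader(TensorDataset(inputs), batch_size=32)

# Resets the BN running stats, then re-estimates them (as a cumulative
# average over batches) by running the model in train mode over the loader
update_bn(loader, model)
model.eval()  # use the re-estimated statistics for inference
```

Because `update_bn` temporarily sets BN's `momentum` to `None`, the new running mean is a plain average over all batches rather than an exponential moving average, so it does not depend on which checkpoint the statistics originally came from.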