I implement EMA as follows:
```python
for param, target_param in zip(net.parameters(), target_net.parameters()):
    # target <- tau * online + (1 - tau) * target
    target_param.data.copy_(ema_tau * param.data + (1 - ema_tau) * target_param.data)
```
However, one of the models I apply this to, a ResNet18 that uses batch norm, doesn’t work with this kind of parameter update. It runs without errors, but it doesn’t learn. Other models perform better with EMA, so I’m not sure why this one breaks.
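A likely cause, offered as an assumption rather than a confirmed diagnosis: `net.parameters()` does not return BatchNorm's running statistics (`running_mean`, `running_var`, `num_batches_tracked`), because they are registered as buffers. If only parameters are averaged, the target network's BN layers keep their initial statistics (mean 0, variance 1) and normalize with them in eval mode. The minimal sketch below also averages floating-point buffers and copies integer ones. The function name `ema_update` and the toy model are hypothetical, and averaging the running statistics (instead of copying them directly from `net`) is one design choice, not the only option.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(net: nn.Module, target_net: nn.Module, ema_tau: float) -> None:
    """Blend net into target_net: target <- tau * online + (1 - tau) * target."""
    # Learnable weights (conv kernels, BN weight/bias, ...).
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.mul_(1.0 - ema_tau).add_(param, alpha=ema_tau)
    # Buffers are NOT included in parameters(): BatchNorm's running_mean and
    # running_var (float) and num_batches_tracked (integer) live here.
    for buf, target_buf in zip(net.buffers(), target_net.buffers()):
        if target_buf.is_floating_point():
            target_buf.mul_(1.0 - ema_tau).add_(buf, alpha=ema_tau)
        else:
            # Integer counters cannot be meaningfully averaged; copy them.
            target_buf.copy_(buf)


# Usage with a hypothetical toy model containing a BatchNorm layer.
torch.manual_seed(0)
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
target_net = copy.deepcopy(net)

net.train()
net(torch.randn(4, 3, 16, 16))  # one forward pass updates net's BN running stats
ema_update(net, target_net, ema_tau=0.5)
```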