EMA and nn.utils.spectral_norm()

I’m trying to keep track of an exponential moving average (EMA) of my model weights, to be used in inference.
I noticed that whenever the model contains spectral normalization (applied with nn.utils.spectral_norm()), the resulting EMA model does not produce the expected results.
Since spectral_norm() seems to ‘patch’ the wrapped module by adding parameters, buffers, and attributes, I’m not sure how these should be handled. I have tried averaging only the trainable parameters, averaging the entire state dict, and removing/re-applying spectral normalization around the averaging update; none of these seems to work.

What would be the correct way to apply EMA to a module with spectral normalization applied?

Hey, I ran into the same problem and have the same question. Did you manage to fix it?
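
One approach that might be worth trying, as a rough sketch rather than a confirmed fix: average only the real parameters (for a layer wrapped with the hook-based nn.utils.spectral_norm() that is weight_orig plus bias) and copy the buffers (weight_u, weight_v) from the training model instead of averaging them, since an average of unit vectors is generally no longer a unit vector. The EMA model is kept in eval mode so its forward pre-hook recomputes weight without running power iteration. The names make_model and ema_update and the toy architecture are hypothetical, used only for illustration. The EMA copy is built from a fresh model plus load_state_dict() as a cautious alternative to copy.deepcopy(), which I assume may fail once the recomputed weight attribute is a non-leaf tensor.

```python
import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical toy network standing in for the real model.
    return nn.Sequential(
        nn.utils.spectral_norm(nn.Linear(8, 16)),
        nn.ReLU(),
        nn.Linear(16, 4),
    )


@torch.no_grad()
def ema_update(ema_model: nn.Module, model: nn.Module, decay: float = 0.999) -> None:
    # Average only true parameters (weight_orig, bias, ...).
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p.detach(), alpha=1.0 - decay)
    # Copy buffers (weight_u, weight_v, ...) rather than averaging them.
    for ema_b, b in zip(ema_model.buffers(), model.buffers()):
        ema_b.copy_(b)


model = make_model()

# Build the EMA copy from a fresh instance and sync it via the state dict.
ema_model = make_model()
ema_model.load_state_dict(model.state_dict())
ema_model.eval()  # eval mode: weight is recomputed without power iteration
for p in ema_model.parameters():
    p.requires_grad_(False)
```

In the training loop, ema_update(ema_model, model) would be called after each optimizer step. Note that the copied u and v were estimated for the training weights, so the normalization of the averaged weight_orig is only approximate; whether that is close enough in practice is something to check on your own model.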