I’m trying to keep track of an exponential moving average (EMA) of my model weights, to be used in inference.

I noticed whenever the model contains spectral normalization (using `nn.utils.spectral_norm()`), the resulting EMA model does not produce expected results.

As `spectral_norm()` seems to ‘patch’ the wrapped module, adding some parameters, buffers, attributes, etc., I’m not really sure how these should be handled. So far I have tried averaging only the trainable parameters, averaging the entire state dict, and removing/applying spectral normalization before/after the averaging update; none of these seems to work.
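To make the "patching" concrete, here is a minimal sketch (not my actual training code) of what the wrapper registers and of a parameter-only EMA update. The helper name `ema_update`, the layer sizes, and the decay value are made up for illustration. With `nn.utils.spectral_norm()`, the original `weight` parameter is replaced by `weight_orig`, and `weight_u` / `weight_v` are added as buffers for the power iteration.

```python
import copy

import torch
import torch.nn as nn

# Wrap a small layer; sizes are arbitrary and only for illustration.
model = nn.utils.spectral_norm(nn.Linear(4, 3))

# Copy before any forward pass, while `weight` is still a plain tensor.
ema_model = copy.deepcopy(model)

# `weight` is no longer a parameter: `weight_orig` is, and the
# power-iteration vectors live in buffers.
print([name for name, _ in model.named_parameters()])
print([name for name, _ in model.named_buffers()])


@torch.no_grad()
def ema_update(ema, online, decay=0.999):
    """Hypothetical helper: EMA over trainable parameters only.

    Buffers such as `weight_u` / `weight_v` are left untouched here,
    which is one of the variants described above.
    """
    for p_ema, p in zip(ema.parameters(), online.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1.0 - decay)


ema_update(ema_model, model)
```

In this variant the EMA copy keeps the `weight_u` / `weight_v` buffers it had at copy time, so the normalized `weight` it computes in the forward pass is based on an averaged `weight_orig` but on power-iteration vectors that were never updated alongside it.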

What would be the correct way to apply EMA to a module with spectral normalization applied?