I’m trying to keep track of an exponential moving average (EMA) of my model weights, to be used in inference.
I noticed that whenever the model contains spectral normalization (applied with nn.utils.spectral_norm()), the resulting EMA model does not produce the expected results.
spectral_norm() seems to ‘patch’ the wrapped module, adding extra parameters, buffers, and attributes, and I’m not sure how these should be handled. So far I have tried averaging only the trainable parameters, averaging the entire state dict, and removing/re-applying spectral normalization before/after the averaging update. None of these seems to work.
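For concreteness, the parameter-averaging variant can be sketched roughly as below. This is only an illustrative sketch, not my exact training code: make_model, ema_update, and the decay value are hypothetical names and choices, and copying the buffers verbatim (rather than averaging them) is an assumption made for the example. Note that for a wrapped layer the trainable parameter is weight_orig, while weight_u and weight_v are buffers.

```python
import torch
import torch.nn as nn


def make_model():
    # Hypothetical toy model; only the spectral_norm wrapping matters here.
    return nn.Sequential(
        nn.utils.spectral_norm(nn.Linear(4, 8)),
        nn.ReLU(),
        nn.Linear(8, 2),
    )


@torch.no_grad()
def ema_update(ema_model, model, decay=0.999):
    # Average the trainable parameters. For a spectral_norm layer this
    # touches `weight_orig`, not the derived `weight` attribute.
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        ema_params[name].mul_(decay).add_(param, alpha=1.0 - decay)

    # Assumption for this sketch: buffers (including `weight_u` and
    # `weight_v`) are copied from the online model instead of averaged.
    ema_buffers = dict(ema_model.named_buffers())
    for name, buf in model.named_buffers():
        ema_buffers[name].copy_(buf)


model = make_model()
ema_model = make_model()
ema_model.load_state_dict(model.state_dict())  # start from the same weights
ema_model.eval()  # eval mode: no power iteration in the forward pass

# Would normally be called once after every optimizer step.
ema_update(ema_model, model, decay=0.9)

with torch.no_grad():
    out = ema_model(torch.randn(3, 4))
```

Printing the names from ema_model.named_parameters() and ema_model.named_buffers() shows which tensors spectral_norm() registered and therefore which ones the update above actually touches.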
What would be the correct way to apply EMA to a module with spectral normalization applied?