How to remove SpectralNorm parametrizations in PyTorch?

The question is as stated in the title.

torch.nn.utils.remove_spectral_norm should work.

Could you give an example of how to remove the _SpectralNorm parametrization from a ParametrizedLinear module? Thanks

You could use this approach:

import torch
import torch.nn as nn

lin = nn.Linear(10, 10)
lin = torch.nn.utils.spectral_norm(lin)

# check spectral_norm params
print(lin.state_dict())

lin = torch.nn.utils.remove_spectral_norm(lin)

# check params again
print(lin.state_dict())
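Comparing the state_dict key names, rather than the printed tensors, makes the before/after difference easier to see. A minimal sketch for the legacy hook-based API, assuming (as far as I know) that it stores the raw weight as weight_orig and the power-iteration vectors as the buffers weight_u and weight_v:

```python
import torch
import torch.nn as nn

lin = nn.Linear(10, 10)
lin = torch.nn.utils.spectral_norm(lin)

# Legacy API: raw weight lives in `weight_orig`, power-iteration
# vectors are registered as buffers `weight_u` and `weight_v`.
keys_before = set(lin.state_dict().keys())
print(sorted(keys_before))

lin = torch.nn.utils.remove_spectral_norm(lin)

# After removal only a plain `weight` parameter (plus `bias`) is left.
keys_after = set(lin.state_dict().keys())
print(sorted(keys_after))
```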

It seems the reimplemented torch.nn.utils.parametrizations.spectral_norm might be missing a corresponding remove_spectral_norm function, since calling torch.nn.utils.remove_spectral_norm on the resulting module fails:

import torch
import torch.nn as nn

lin = nn.Linear(10, 10)
lin = torch.nn.utils.parametrizations.spectral_norm(lin)

# check spectral_norm params
print(lin.state_dict())

lin = torch.nn.utils.remove_spectral_norm(lin)
ValueError: spectral_norm of 'weight' not found in ParametrizedLinear(
  in_features=10, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _SpectralNorm()
    )
  )
)
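The failure appears to come from the two APIs storing spectral norm differently: as I read it, the legacy remove_spectral_norm searches the module's forward pre-hooks for a spectral-norm hook, while the new function registers a parametrization instead (hence the ParametrizedLinear class in the error). A small sketch that makes this visible, assuming a PyTorch version that ships torch.nn.utils.parametrizations.spectral_norm (1.10 or later, as far as I know):

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize

lin = nn.Linear(10, 10)
lin = torch.nn.utils.parametrizations.spectral_norm(lin)

# New API: no forward pre-hook; the class is swapped for a
# Parametrized* subclass and the raw tensor is kept under
# `parametrizations.weight.original`.
print(type(lin).__name__)                          # ParametrizedLinear
print(parametrize.is_parametrized(lin, "weight"))  # True
print(sorted(lin.state_dict().keys()))

# The legacy remover finds no spectral-norm hook and raises ValueError.
try:
    torch.nn.utils.remove_spectral_norm(lin)
except ValueError as err:
    error_message = str(err)
else:
    error_message = None
print(error_message)
```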

@albanD do you know if this method is just missing, or if spectral_norm should be removed differently when using the new API?


The correct way to remove parametrizations (this one or any other from nn.utils.parametrizations) is torch.nn.utils.parametrize.remove_parametrizations; see torch.nn.utils.parametrizations.spectral_norm — PyTorch 1.10.0 documentation
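A minimal sketch of that call, assuming the documented signature remove_parametrizations(module, tensor_name, leave_parametrized=True). The strip_spectral_norm helper at the end is hypothetical (not part of PyTorch); it just dispatches to whichever removal function matches the API that was used to add spectral norm.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import spectral_norm

lin = spectral_norm(nn.Linear(10, 10))

# In eval mode no extra power iterations run on access, so reading
# `lin.weight` gives the same normalized tensor every time.
lin.eval()
with torch.no_grad():
    normalized = lin.weight.clone()

# leave_parametrized=True (the default) keeps the current normalized value
# as a plain `weight` parameter; False would restore the unnormalized one.
parametrize.remove_parametrizations(lin, "weight", leave_parametrized=True)

print(type(lin).__name__)                      # Linear
print(sorted(lin.state_dict().keys()))         # ['bias', 'weight']
print(torch.allclose(lin.weight, normalized))  # True


def strip_spectral_norm(module: nn.Module, name: str = "weight") -> nn.Module:
    """Hypothetical helper: undo spectral norm added by either API."""
    if parametrize.is_parametrized(module, name):
        # Parametrization API. Note: this removes *every* parametrization
        # registered on `name`, not only the spectral norm one.
        return parametrize.remove_parametrizations(module, name)
    # Legacy hook-based API (torch.nn.utils.spectral_norm).
    return torch.nn.utils.remove_spectral_norm(module, name)


old_style = strip_spectral_norm(torch.nn.utils.spectral_norm(nn.Linear(4, 4)))
new_style = strip_spectral_norm(spectral_norm(nn.Linear(4, 4)))
```

Because remove_parametrizations strips every parametrization on that tensor, a helper like this is only a reasonable shortcut when spectral norm is the only one registered.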

As a side note, all this is documented in the parametrizations tutorial: Parametrizations Tutorial — PyTorch Tutorials 1.10.0+cu102 documentation
