The question is as stated in the title.
torch.nn.utils.remove_spectral_norm should work.
You could use this approach:
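```python
import torch
from torch import nn

lin = nn.Linear(10, 10)
lin = torch.nn.utils.spectral_norm(lin)
# check spectral_norm params (weight_orig, weight_u, weight_v)
print(lin.state_dict())
lin = torch.nn.utils.remove_spectral_norm(lin)
# check params again (plain weight and bias)
print(lin.state_dict())
```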
It seems that the reimplemented torch.nn.utils.parametrizations.spectral_norm might be missing a corresponding remove_spectral_norm call, since calling torch.nn.utils.remove_spectral_norm on a module parametrized with it doesn't work:
```python
import torch
from torch import nn

lin = nn.Linear(10, 10)
lin = torch.nn.utils.parametrizations.spectral_norm(lin)
# check spectral_norm params
print(lin.state_dict())
lin = torch.nn.utils.remove_spectral_norm(lin)  # raises the ValueError below
```

```
ValueError: spectral_norm of 'weight' not found in ParametrizedLinear(
  in_features=10, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _SpectralNorm()
    )
  )
)
```
@albanD do you know if this method is just missing, or if spectral_norm should be removed in a different way with the new API?
The correct way to remove parametrizations (this one or any other from nn.utils.parametrizations) is through torch.nn.utils.parametrize.remove_parametrizations; see the PyTorch 1.10.0 documentation for torch.nn.utils.parametrizations.spectral_norm.
As a side note, all of this is documented in the Parametrizations Tutorial (PyTorch Tutorials 1.10.0+cu102 documentation).