How to implement spectral normalization in PyTorch

Has anyone implemented spectral normalization (https://openreview.net/forum?id=B1QRgziT-) in PyTorch? In the original paper, the authors apply spectral normalization to the weights of a GAN and achieve good results. The training algorithm is given as pseudocode in the paper.


Is there a simple way to compute the spectrally normalized weights Wsn? One naive approach is to pull the parameters out of a module and process them separately, or to wrap the module in a forward hook (though I doubt that works correctly with backpropagation). Can anyone help?

I think the best way to implement this is to write a custom optimizer. You can start with SGD: https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py.

Thank you very much for your insights. However, when I wrap this in SGD, I run into a problem. The forward pass uses W instead of Wsn to compute the output, whereas the algorithm pseudocode uses Wsn to compute the output. Any ideas?

Maybe we could implement it like the weight norm code? (https://gist.github.com/rtqichen/b22a9c6bfc4f36e605a7b3ac1ab4122f)
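For illustration, here is a minimal sketch of what a weight-norm-style layer could look like for spectral normalization. The class name `SNLinear` and the `n_power_iterations` argument are made up for this example, and this is not the code from the gist above. The raw weight W stays the trainable parameter, the forward pass uses W / sigma(W), and the power-iteration vectors are treated as constants so that autograd only differentiates through W.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SNLinear(nn.Linear):
    """Hypothetical linear layer whose forward pass uses W / sigma(W)."""

    def __init__(self, in_features, out_features, bias=True, n_power_iterations=1):
        super().__init__(in_features, out_features, bias)
        self.n_power_iterations = n_power_iterations
        # Running estimate of the leading left singular vector of W.
        self.register_buffer("u", F.normalize(torch.randn(out_features), dim=0))

    def forward(self, x):
        w = self.weight
        # Power iteration runs without gradient tracking: u and v are constants.
        with torch.no_grad():
            u = self.u.clone()
            v = F.normalize(w.t() @ u, dim=0)
            if self.training:
                for _ in range(self.n_power_iterations):
                    v = F.normalize(w.t() @ u, dim=0)
                    u = F.normalize(w @ v, dim=0)
                self.u.copy_(u)  # reuse the estimate in the next forward pass
        # sigma depends on w, so the gradient flows back to the raw weight.
        sigma = torch.dot(u, w @ v)
        return F.linear(x, w / sigma, self.bias)


layer = SNLinear(6, 4)
out = layer(torch.randn(2, 6))
out.sum().backward()  # gradients reach layer.weight through the normalization
```

Recent PyTorch releases also ship `torch.nn.utils.spectral_norm`, which wraps an existing module along similar lines, so a hand-written layer like this is mainly useful for understanding the mechanics.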

I see. So the forward is supposed to use normalized weights? You can still do it as an optimizer though. Just store the unnormalized weights in optimizer state dict and in each step update the model parameters to normalized weights.
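As a rough sketch of the optimizer idea described above: the class name `SNProjectedSGD` is hypothetical, and for simplicity it computes the exact spectral norm with `torch.linalg.matrix_norm` rather than power iteration. One caveat to keep in mind is that the gradient is taken with respect to the normalized weights and then applied directly to the stored raw weights, so the gradient through the normalization itself is ignored. This may therefore behave differently from the algorithm in the paper.

```python
import torch


class SNProjectedSGD(torch.optim.Optimizer):
    """Hypothetical plain SGD that keeps raw weights in its state and
    writes spectrally normalized weights back into the model."""

    def __init__(self, params, lr=0.01):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if "raw" not in state:
                    # First step: start the raw copy from the current value.
                    state["raw"] = p.detach().clone()
                raw = state["raw"]
                raw.add_(p.grad, alpha=-group["lr"])  # SGD update on raw weights
                if p.dim() >= 2:
                    # Treat conv-style weights as a 2-D matrix (out, rest).
                    mat = raw.reshape(raw.shape[0], -1)
                    sigma = torch.linalg.matrix_norm(mat, ord=2)
                    p.copy_(raw / sigma)
                else:
                    p.copy_(raw)  # biases are left unnormalized


layer = torch.nn.Linear(4, 3)
opt = SNProjectedSGD(layer.parameters(), lr=0.1)
loss = layer(torch.randn(8, 4)).pow(2).mean()
loss.backward()
opt.step()  # layer.weight now has spectral norm close to 1
```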

Here is the Chainer code for the max singular value approximation:
(https://github.com/pfnet-research/chainer-gan-lib/tree/master/common/sn)
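For readers who want the same idea in PyTorch, the largest singular value can be approximated with power iteration, which is the approximation the paper uses. The following is a minimal sketch, not a port of the Chainer code; the function name `max_singular_value_estimate` and its arguments are chosen for this example.

```python
import torch


def max_singular_value_estimate(weight, n_iters=50, eps=1e-12):
    """Estimate the largest singular value of `weight` by power iteration."""
    # Flatten conv-style weights (out, in, kh, kw) into a 2-D matrix.
    w = weight.reshape(weight.shape[0], -1)
    u = torch.randn(w.shape[0], dtype=w.dtype, device=w.device)
    u = u / (u.norm() + eps)
    for _ in range(n_iters):
        v = w.t() @ u
        v = v / (v.norm() + eps)
        u = w @ v
        u = u / (u.norm() + eps)
    # With converged u and v, u^T W v approximates the top singular value.
    return torch.dot(u, w @ v)


W = torch.randn(64, 3, 3, 3)
sigma = max_singular_value_estimate(W, n_iters=200)
# Exact value for comparison (more expensive for large matrices).
exact = torch.linalg.matrix_norm(W.reshape(64, -1), ord=2)
```

In training code the iteration count is usually kept very small and `u` is carried over between steps, since the weights change only slightly per update.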

Thank you very much for your answer. I have implemented one based on the Chainer code.

Hello @Xinwei_he,

A quick question: are you adding the spectral norm of the weights as an extra regularizer, or are you embedding it in the SGD algorithm itself?

I am actually trying to see the effect of the spectral norm of the weights as a regularizer (like an L2 loss). Would your approach be useful in my case? My concern is that the function computing the spectral norm wouldn't be differentiable, and whatever I add to the total loss has to be differentiable.

Thanks!
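On the regularizer question, one possible approach is sketched below, under the assumption that the power-iteration vectors u and v can be treated as constants. In that case sigma = u^T W v is an ordinary differentiable function of W (its gradient is the outer product of u and v), so it can be added to a loss. The function name `spectral_norm_penalty` and the weight `lam` are hypothetical.

```python
import torch
import torch.nn.functional as F


def spectral_norm_penalty(weight, n_iters=5):
    """Differentiable estimate of the largest singular value of `weight`."""
    w = weight.reshape(weight.shape[0], -1)
    # u and v are computed without gradient tracking and act as constants.
    with torch.no_grad():
        u = F.normalize(torch.randn(w.shape[0], dtype=w.dtype, device=w.device), dim=0)
        for _ in range(n_iters):
            v = F.normalize(w.t() @ u, dim=0)
            u = F.normalize(w @ v, dim=0)
    # Only this line is tracked by autograd.
    return torch.dot(u, w @ v)


model = torch.nn.Linear(10, 5)
x, y = torch.randn(8, 10), torch.randn(8, 5)
lam = 0.01  # hypothetical regularization strength
loss = F.mse_loss(model(x), y) + lam * sum(
    spectral_norm_penalty(p) for p in model.parameters() if p.dim() >= 2
)
loss.backward()
```

Because `u` is re-drawn on every call, a small `n_iters` gives a noisy estimate; keeping `u` in a buffer between steps, as spectral normalization layers typically do, would likely make this cheaper and more stable.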