Gradient penalty with respect to the network parameters

I would like to use the gradient penalty introduced in the NIPS paper https://arxiv.org/abs/1706.04156. In the paper, every generator update adds a penalty on the squared norm of the gradient of the GAN objective with respect to the discriminator parameters.

I know there is a discussion in a previous post (https://discuss.pytorch.org/t/how-to-implement-gradient-penalty-in-pytorch/1656?u=zhangboknight), where the gradient penalty is calculated with respect to the data points.

In my case, the gradient penalty is calculated with respect to the discriminator parameters.
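As far as I understand the paper, the updates look roughly like the following, where V is the GAN objective, θ_D and θ_G are the discriminator and generator parameters, α is the step size and η is the regularization weight (the notation here is mine and may not match the paper exactly). The discriminator does plain gradient ascent on V; only the generator step carries the penalty:

```latex
\begin{aligned}
\theta_D &\leftarrow \theta_D + \alpha \, \nabla_{\theta_D} V(\theta_D, \theta_G) \\
\theta_G &\leftarrow \theta_G - \alpha \, \nabla_{\theta_G} \Big( V(\theta_D, \theta_G) + \eta \, \big\lVert \nabla_{\theta_D} V(\theta_D, \theta_G) \big\rVert^2 \Big)
\end{aligned}
```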

The authors of the paper give an implementation using tf.gradients(V, discriminator_vars) in TensorFlow (https://github.com/locuslab/gradient_regularized_gan/blob/c1f61272f6176d4d13016779dbe730b346ae21a1/gaussian-toy-regularized.py#L112).

Is there a similar implementation in PyTorch? Can I still use torch.autograd.grad to calculate the gradient with respect to the discriminator parameters? Or can I use .retain_grad() as suggested in https://discuss.pytorch.org/t/how-do-i-calculate-the-gradients-of-a-non-leaf-variable-w-r-t-to-a-loss-function/5112/2?u=zhangboknight? I am not quite sure. Since this technique greatly helps stabilize GAN training, the answer may also help others who use PyTorch. Many thanks!

I think autograd.grad with create_graph=True would work for you.
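A minimal sketch of why create_graph=True matters, using a hypothetical one-weight "discriminator" so the numbers can be checked by hand: with loss = (w*x)^2 the gradient is 2*w*x^2, and differentiating the squared gradient again gives 8*w*x^4. For w = 3 and x = 2 that is 24 and 384.

```python
import torch
import torch.nn as nn

# Hypothetical one-weight "discriminator": D(x) = w * x, no bias.
disc = nn.Linear(1, 1, bias=False)
with torch.no_grad():
    disc.weight.fill_(3.0)

x = torch.tensor([[2.0]])
loss = disc(x).pow(2).sum()  # loss = (w * x)^2

# create_graph=True records the graph of the gradient computation itself,
# so the returned gradient can be differentiated again.
(g,) = torch.autograd.grad(loss, list(disc.parameters()), create_graph=True)
print(g.item())  # dloss/dw = 2 * w * x^2 = 24.0

# Penalize the squared gradient and backpropagate through it.
penalty = g.pow(2).sum()  # (2 * w * x^2)^2
penalty.backward()
print(disc.weight.grad.item())  # d(penalty)/dw = 8 * w * x^4 = 384.0
```

Without create_graph=True, g would not require grad, and penalty.backward() would raise an error.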


Thanks a lot! Now I can calculate the norm of the gradient with:
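```python
import torch

# Gradients of the loss w.r.t. every parameter; create_graph=True keeps them
# differentiable so the norm can later be used as a penalty term.
grad_params = torch.autograd.grad(loss, model.parameters(), create_graph=True)
grad_norm = 0
for grad in grad_params:
    grad_norm += grad.pow(2).sum()
grad_norm = grad_norm.sqrt()
```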

Other references are the posts https://discuss.pytorch.org/t/hessian-vector-product-with-create-graph-flag/8357?u=zhangboknight and https://discuss.pytorch.org/t/issues-computing-hessian-vector-product/2709/7?u=zhangboknight.
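Those posts are relevant because differentiating a squared gradient norm is a Hessian-vector product: the gradient of 0.5 * ||∇L||^2 w.r.t. θ equals H ∇L, with H the Hessian of L. Below is a small check of that identity on a hypothetical cubic loss, in float64 and small enough to form H explicitly with torch.autograd.functional.hessian. In the GAN penalty the outer derivative is taken w.r.t. the generator parameters instead, so it becomes a mixed second-derivative product, but autograd handles it the same way.

```python
import torch

torch.manual_seed(0)
M = torch.randn(3, 3, dtype=torch.float64)
A = M @ M.T  # symmetric, so the quadratic term has Hessian A
theta = torch.randn(3, dtype=torch.float64, requires_grad=True)

def loss_fn(t):
    # Hypothetical loss: quadratic plus a cubic term, so the Hessian depends on t.
    return 0.5 * t @ A @ t + (t ** 3).sum() / 3

loss = loss_fn(theta)
(g,) = torch.autograd.grad(loss, theta, create_graph=True)  # g = A t + t^2

# Differentiating 0.5 * ||g||^2 w.r.t. theta gives H @ g, a Hessian-vector product.
(hvp,) = torch.autograd.grad(0.5 * g.pow(2).sum(), theta)

# Explicit Hessian H = A + diag(2 t), only feasible for tiny problems.
H = torch.autograd.functional.hessian(loss_fn, theta.detach())
print(torch.allclose(hvp, H @ g.detach()))
```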
I will add this gradient norm as a penalty term and see how it affects the results.
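A minimal sketch of one regularized training step, assuming hypothetical toy MLPs on 2-D data, plain SGD, the saturating objective V = E[log D(x)] + E[log(1 - D(G(z)))] written with logsigmoid on logits, and an arbitrary η = 0.5. This is a toy version for illustration, not the paper's code. The discriminator ascends V; the generator descends V + η * ||∇_{θ_D} V||^2.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy networks for 2-D data; D outputs a logit.
G = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))
D = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))
opt_g = torch.optim.SGD(G.parameters(), lr=1e-2)
opt_d = torch.optim.SGD(D.parameters(), lr=1e-2)
eta = 0.5  # regularization weight (assumed value)

def gan_objective(real, z):
    # V = E[log sigmoid(D(x))] + E[log(1 - sigmoid(D(G(z))))]
    # using log(1 - sigmoid(a)) = logsigmoid(-a)
    fake = G(z)
    return F.logsigmoid(D(real)).mean() + F.logsigmoid(-D(fake)).mean()

real = torch.randn(64, 2) + 2.0
z = torch.randn(64, 2)

# Discriminator step: gradient ascent on V (i.e. minimize -V).
opt_d.zero_grad()
(-gan_objective(real, z)).backward()
opt_d.step()

# Generator step: minimize V + eta * ||grad_{theta_D} V||^2.
opt_g.zero_grad()
V = gan_objective(real, z)
grads = torch.autograd.grad(V, list(D.parameters()), create_graph=True)
penalty = sum(g.pow(2).sum() for g in grads)
# Only accumulate gradients into the generator parameters.
(V + eta * penalty).backward(inputs=list(G.parameters()))
opt_g.step()
```

Passing inputs=list(G.parameters()) to backward (supported in recent PyTorch versions) restricts gradient accumulation to the generator, so the generator step does not write into D's .grad fields.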


Hello Tony,

thanks for highlighting the paper!

I put my PyTorch version of the experiment in a quick Regularized gradient descent GAN optimization notebook. The core is very similar to what you and Yun Chen came up with.

I have a specific problem in mind that I will try it on. Thanks again for bringing this up here.

Best regards

Thomas

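```python
import torch
from torch.autograd import grad

# Critic: the critic network; interpolates: points between real and fake
# samples, created with requires_grad=True so the gradient w.r.t. them exists.
preds = Critic(interpolates)
gradients = grad(outputs=preds, inputs=interpolates,
                 grad_outputs=torch.ones_like(preds),  # already on preds' device
                 retain_graph=True, create_graph=True)[0]
gradients = gradients.view(gradients.size(0), -1)  # flatten each sample
gradient_norm = gradients.norm(2, dim=1)           # per-sample L2 norm
gradient_penalty = ((gradient_norm - 1) ** 2).mean()
```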
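To sanity-check that snippet end to end, here is the same computation wrapped in a hypothetical helper gradient_penalty(critic, real, fake), including the interpolation step (one uniform coefficient per sample, as in WGAN-GP). With a linear critic the gradient w.r.t. the input is just the weight row, so a weight of norm 3 must give a penalty of (3 - 1)^2 = 4. Note that this penalty is taken w.r.t. the inputs, not the discriminator parameters this thread is about.

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake):
    # One interpolation coefficient per sample, broadcast over remaining dims.
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    # Detach so interpolates is a fresh leaf that requires grad.
    interpolates = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    preds = critic(interpolates)
    (gradients,) = torch.autograd.grad(
        outputs=preds, inputs=interpolates,
        grad_outputs=torch.ones_like(preds),
        create_graph=True,  # lets the penalty be backpropagated to the critic
    )
    gradient_norm = gradients.reshape(gradients.size(0), -1).norm(2, dim=1)
    return ((gradient_norm - 1) ** 2).mean()

# Linear critic whose input gradient is the weight row [3, 0, 0, 0] (norm 3).
critic = nn.Linear(4, 1)
with torch.no_grad():
    critic.weight.copy_(torch.tensor([[3.0, 0.0, 0.0, 0.0]]))
real, fake = torch.randn(8, 4), torch.randn(8, 4)
gp = gradient_penalty(critic, real, fake)
print(gp.item())  # (3 - 1)^2 = 4.0
```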