I was reading Improved Training of Wasserstein GANs and thinking about how it could be implemented in PyTorch. It doesn't seem too complex, but handling the gradient penalty in the loss troubles me.
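For reference, this is the critic loss from the paper as I read it. λ = 10 is the default used there. The penalty term contains a gradient of D, so minimizing the loss means backpropagating through that gradient ("double backprop"). That is the part that needs special handling:

```latex
L = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big]
  - \mathbb{E}_{x \sim \mathbb{P}_r}\big[D(x)\big]
  + \lambda\, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big],
\qquad \hat{x} = \epsilon x + (1 - \epsilon)\tilde{x},\quad \epsilon \sim U[0, 1]
```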
In the TensorFlow implementation, the author uses
I wonder if there is an easy way to handle the gradient penalty.
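One way this might look in PyTorch is a small helper built on `torch.autograd.grad` with `create_graph=True`. This is a minimal sketch, not the authors' code. `gradient_penalty`, `critic`, `real`, `fake` and `lambda_gp` are hypothetical names. It also assumes the critic scores each sample independently (no batch norm across the batch). That assumption is what makes differentiating `scores.sum()` give per-sample gradients.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    """Hypothetical helper: lambda * E[(||grad_{x_hat} critic(x_hat)||_2 - 1)^2]."""
    # One interpolation weight per sample, broadcast over the remaining dims.
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    # Random points on straight lines between real and fake samples.
    x_hat = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)
    scores = critic(x_hat)
    # create_graph=True makes the returned gradient differentiable itself,
    # so the penalty can be backpropagated into the critic's parameters.
    (grads,) = torch.autograd.grad(scores.sum(), x_hat, create_graph=True)
    # Per-sample L2 norm over all non-batch dimensions.
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()
```

In a critic step it would then be added to the loss, e.g. `loss_D = critic(fake).mean() - critic(real).mean() + gradient_penalty(critic, real, fake)`, with `fake` detached from the generator.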
Here is my idea of how to implement it. I'm not sure whether it works, or whether it works the way I think:
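import torch
import torch.nn as nn
LAMBDA = 10  # penalty weight; 10 is the default used in the paper
model_D = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))  # placeholder critic
optimizer_D = torch.optim.Adam(model_D.parameters())
x = torch.randn(32, 2)  # fake batch (placeholder; e.g. generator output, detached)
y = torch.randn(32, 2)  # real batch (placeholder)
alpha = torch.rand(x.size(0), 1)  # one interpolation weight per sample
x_hat = (alpha * x + (1 - alpha) * y).detach().requires_grad_(True)
# create_graph=True keeps the gradient differentiable, so the penalty can train model_D
grad = torch.autograd.grad(model_D(x_hat).sum(), x_hat, create_graph=True)[0]
penalty = LAMBDA * ((grad.flatten(1).norm(2, dim=1) - 1) ** 2).mean()
loss = model_D(x).mean() - model_D(y).mean() + penalty  # D(fake) - D(real) + gradient penalty
optimizer_D.zero_grad()
loss.backward()
optimizer_D.step()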
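A quick check of why reading `x_hat.grad` after an ordinary `backward()` is not enough for the penalty, while `torch.autograd.grad(..., create_graph=True)` is. This sketch uses a placeholder linear critic and random data, both assumptions for illustration only. Both calls give the same numbers, but only the second result keeps its autograd history back to the critic's weights:

```python
import torch
import torch.nn as nn

# Placeholder critic and inputs, only used to compare the two approaches.
critic = nn.Linear(3, 1, bias=False)
x_hat = torch.randn(4, 3, requires_grad=True)

# (a) Ordinary backward(): the gradient lands in x_hat.grad as a plain tensor
# with no history, so a penalty built from it cannot update the critic.
critic(x_hat).sum().backward()
grad_from_backward = x_hat.grad

# (b) autograd.grad with create_graph=True: the returned gradient stays
# connected to critic.weight and can be differentiated a second time.
(grad_from_autograd,) = torch.autograd.grad(critic(x_hat).sum(), x_hat, create_graph=True)

print(grad_from_backward.requires_grad)  # False
print(grad_from_autograd.requires_grad)  # True
```

Note that the `backward()` in (a) also accumulates into `critic.weight.grad`, which would have to be zeroed before the real critic update.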