KunWangV (Kun Wang), December 31, 2018, 5:28am
I am trying to calculate a gradient penalty, but it seems that the gradients cannot backpropagate to the parameters of netD.
```python
output = netD(fuse)
gradients = grad(output, fuse, only_inputs=True, create_graph=False, retain_graph=False,
                 grad_outputs=torch.ones(output.size(), device=device))[0]
gradients.requires_grad = True
D_penalty = penalty_scale * (gradients.view(batch_size, -1).norm(2, dim=-1) - 1).pow(2).mean()
D_penalty.backward()
```
albanD (Alban D), January 1, 2019, 8:53am
Hi,
You need to call grad with create_graph=True instead of setting requires_grad by hand. With create_graph=False the returned gradients are detached from the graph, so flipping requires_grad afterwards only turns them into a new leaf, and D_penalty.backward() never reaches the parameters of netD.
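To see the difference, here is a small standalone sketch (a toy tensor, not your code) showing that a gradient obtained with create_graph=False is a detached leaf, while create_graph=True keeps it connected to the graph:

```python
import torch

x = torch.randn(3, requires_grad=True)
y = (x ** 2).sum()

g = torch.autograd.grad(y, x, create_graph=False)[0]
print(g.grad_fn)        # None: g is detached from the graph
g.requires_grad = True  # g becomes a new leaf; a backward() through it
                        # can never reach x (or model parameters)

g2 = torch.autograd.grad(y, x, create_graph=True)[0]
print(g2.grad_fn)       # e.g. <MulBackward0>: g2 stays connected to x
```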
KunWangV (Kun Wang), January 1, 2019, 8:57am
I have changed my code, but it still does not work.
```python
# interpolate between real and fake samples
fuse_factor = torch.rand(batch_size, 1, 1, 1, device=device)
fuse = (fuse_factor * real + (1 - fuse_factor) * fake).requires_grad_(True)
output = netD(fuse)

# gradient of D(fuse) w.r.t. the interpolates, kept in the graph
gd_outputs = torch.ones(batch_size, requires_grad=False, device=device)
gradients = autograd.grad(
    outputs=output,
    inputs=fuse,
    grad_outputs=gd_outputs,
    create_graph=True,
    retain_graph=True,
    only_inputs=True,
)[0]

# one-centered gradient penalty
D_penalty = penalty_scale * (gradients.view(batch_size, -1).norm(2, dim=-1) - 1).pow(2).mean()
D_penalty.backward()
```
albanD (Alban D), January 7, 2019, 8:39am
This code should propagate gradients properly to the parameters of netD and to fuse. You can check that their .grad fields are populated after the call to backward().
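For reference, a minimal self-contained version of that check. The toy discriminator and random tensors below are hypothetical stand-ins for netD and the real/fake batches, just to make the sketch runnable:

```python
import torch
from torch import autograd, nn

device = "cpu"
batch_size, penalty_scale = 4, 10.0

# hypothetical stand-ins for netD, real and fake
netD = nn.Sequential(nn.Flatten(), nn.Linear(2 * 8 * 8, 16), nn.Tanh(),
                     nn.Linear(16, 1, bias=False))
real = torch.randn(batch_size, 2, 8, 8, device=device)
fake = torch.randn(batch_size, 2, 8, 8, device=device)

fuse_factor = torch.rand(batch_size, 1, 1, 1, device=device)
fuse = (fuse_factor * real + (1 - fuse_factor) * fake).requires_grad_(True)
output = netD(fuse)

gradients = autograd.grad(
    outputs=output,
    inputs=fuse,
    grad_outputs=torch.ones_like(output),  # matches output's (batch_size, 1) shape
    create_graph=True,
)[0]
D_penalty = penalty_scale * (gradients.view(batch_size, -1).norm(2, dim=-1) - 1).pow(2).mean()
D_penalty.backward()

# both the interpolates and netD's parameters received gradients
print(fuse.grad is not None)                               # True
print(all(p.grad is not None for p in netD.parameters()))  # True
```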