The outputs of the grad function cannot backpropagate


(Kun Wang) #1

I am trying to compute a gradient penalty, but it seems that the gradients cannot backpropagate to the parameters of netD.

        output = netD(fuse)
        gradients = grad(output, fuse, only_inputs=True, create_graph=False, retain_graph=False, grad_outputs=torch.ones(output.size(), device=device))[0]
        gradients.requires_grad = True
        D_penalty = penalty_scale * (gradients.view(batch_size, -1).norm(2, dim=-1) - 1).pow(2).mean()
        D_penalty.backward()
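A minimal sketch that reproduces the symptom, assuming a toy single-layer `netD` (a hypothetical stand-in for the real discriminator) and random inputs. It shows that the tensor returned by `grad` has no autograd history in this setup, and that after `backward()` the parameters of `netD` still have no `.grad`:

```python
import torch
import torch.nn as nn
from torch.autograd import grad

torch.manual_seed(0)
netD = nn.Linear(4, 1)  # hypothetical stand-in for the real discriminator
fuse = torch.randn(3, 4, requires_grad=True)

output = netD(fuse)
gradients = grad(output, fuse, grad_outputs=torch.ones_like(output),
                 create_graph=False)[0]
print(gradients.grad_fn)  # None: no history back to netD

# Setting requires_grad by hand makes `gradients` a brand-new leaf tensor,
# still unconnected to netD's parameters.
gradients.requires_grad = True
D_penalty = (gradients.norm(2, dim=-1) - 1).pow(2).mean()
D_penalty.backward()

print(netD.weight.grad)            # None: nothing reached netD
print(gradients.grad is not None)  # True: backward stopped at the new leaf
```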

(Alban D) #2

Hi,

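You need to call grad with create_graph=True instead of setting requires_grad by hand. With create_graph=False, the tensor returned by grad is detached from the graph, so setting requires_grad on it only turns it into a new leaf; backward() then stops there and never reaches the parameters of netD.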
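A small sketch of the fix on the same kind of hypothetical toy `netD`: with `create_graph=True`, the returned gradient already requires grad and has a `grad_fn`, i.e. it is part of the autograd graph and depends on `netD`'s weights, so a loss built from it backpropagates into them.

```python
import torch
import torch.nn as nn
from torch.autograd import grad

torch.manual_seed(0)
netD = nn.Linear(4, 1)  # hypothetical stand-in for the real discriminator
fuse = torch.randn(3, 4, requires_grad=True)
output = netD(fuse)

gradients = grad(output, fuse, grad_outputs=torch.ones_like(output),
                 create_graph=True)[0]
print(gradients.requires_grad)        # True, without setting it by hand
print(gradients.grad_fn is not None)  # True: it has autograd history

D_penalty = (gradients.norm(2, dim=-1) - 1).pow(2).mean()
D_penalty.backward()
print(netD.weight.grad is not None)   # True: the penalty reached netD
```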


(Kun Wang) #3

I have changed my code, but it still does not work.

        fuse_factor = torch.rand(batch_size, 1, 1, 1, device=device)
        fuse = (fuse_factor * real + (1 - fuse_factor) * fake).requires_grad_(True)
        output = netD(fuse)
        gd_outputs = torch.ones(batch_size, requires_grad=False, device=device)

        gradients = autograd.grad(
            outputs=output,
            inputs=fuse,
            grad_outputs=gd_outputs,
            create_graph=True,
            retain_graph=True,
            only_inputs=True
        )[0]

        D_penalty = penalty_scale * (gradients.view(batch_size, -1).norm(2, dim=-1) - 1).pow(2).mean()
        D_penalty.backward()
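For reference, a self-contained sketch of the penalty that the revised code computes, wrapped in a function. It assumes image-shaped `real`/`fake` batches, a hypothetical tiny convolutional `netD`, and `penalty_scale=10.0` as an assumed default. It passes `torch.ones_like(output)` as `grad_outputs` so the shape always matches `output` (`torch.ones(batch_size)` only matches if `netD` returns a 1-D tensor), and it drops `only_inputs`, which the PyTorch documentation lists as deprecated and ignored.

```python
import torch
import torch.nn as nn
from torch import autograd


def gradient_penalty(netD, real, fake, penalty_scale=10.0):
    """Penalty on the gradient norm at random interpolates of real and fake samples."""
    batch_size = real.size(0)
    # One mixing factor per sample, broadcast over channels/height/width.
    fuse_factor = torch.rand(batch_size, 1, 1, 1, device=real.device)
    fuse = (fuse_factor * real + (1 - fuse_factor) * fake).requires_grad_(True)
    output = netD(fuse)

    gradients = autograd.grad(
        outputs=output,
        inputs=fuse,
        grad_outputs=torch.ones_like(output),  # same shape as output
        create_graph=True,  # keep the graph so the penalty can reach netD
        retain_graph=True,
    )[0]
    norms = gradients.reshape(batch_size, -1).norm(2, dim=-1)
    return penalty_scale * (norms - 1).pow(2).mean()


torch.manual_seed(0)
# Hypothetical tiny discriminator for 1x8x8 images.
netD = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.LeakyReLU(0.2),
    nn.Flatten(),
    nn.Linear(4 * 8 * 8, 1),
)
real = torch.randn(5, 1, 8, 8)
fake = torch.randn(5, 1, 8, 8)  # stand-in for a (detached) generator output
D_penalty = gradient_penalty(netD, real, fake)
D_penalty.backward()
print(netD[0].weight.grad is not None)  # True
```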

(Alban D) #4

This code should propagate gradients properly to the parameters of netD and to fuse. You can check that their .grad fields are populated after the call to backward().
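One way to run that check, sketched with a hypothetical two-layer `netD`: after `backward()`, list which parameters have a populated `.grad`. A parameter that does not influence the gradient with respect to the input at all keeps `.grad` as `None` even when the penalty is wired up correctly. Here that is the bias of the last linear layer, which only shifts the output by a constant.

```python
import torch
import torch.nn as nn
from torch import autograd

torch.manual_seed(0)
# Hypothetical discriminator; the Tanh makes the input-gradient depend on the first bias.
netD = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
fuse = torch.randn(3, 4, requires_grad=True)
output = netD(fuse)
gradients = autograd.grad(output, fuse, grad_outputs=torch.ones_like(output),
                          create_graph=True)[0]
D_penalty = (gradients.norm(2, dim=-1) - 1).pow(2).mean()
D_penalty.backward()

# Report which tensors received a gradient from the penalty.
populated = {name: p.grad is not None for name, p in netD.named_parameters()}
for name, has_grad in populated.items():
    print(f"{name}: {'populated' if has_grad else 'None'}")
print("fuse:", "populated" if fuse.grad is not None else "None")
```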