What is the correct way of computing a grad penalty using AMP?

I’m converting my distributed code to work with PyTorch’s AMP, but I’m confused about how I should compute the grad penalty term of my WGAN-GP loss. I’ve referred to the documentation on this, but I believe my use case is slightly different.

Here’s my code for computing the grad penalty:

out = discrim(sample)

# gradient of the critic output w.r.t. the generated sample
gradients = torch.autograd.grad(inputs=sample, outputs=out,
    grad_outputs=torch.ones(out.shape).to(sample.device),
    create_graph=True, retain_graph=True, only_inputs=True)[0]

You may assume that sample is a tensor yielded by my generator and out is an N x 1 tensor. I’ll start off by saying that this code works without AMP, but I’m confused about how I’m supposed to use AMP with it. For context, sample comes out of an autocast region upstream, roughly like this (gen and noise stand in for my actual code):
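with autocast():
    sample = gen(noise)   # generated batch; the autograd graph reaches back into the generator

This is what I’m currently trying: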

with autocast(False):
    sample = scaler.scale(sample)
    out = discrim(sample)

    gradients = torch.autograd.grad(inputs=sample, outputs=out,
        grad_outputs=torch.ones(out.shape).to(sample.device),
        create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients / scaler.get_scale()

Note that sample was computed under autocast(True), but I’m not sure if that matters. This seems to work, but I’d like to compute out under the mixed-precision context for speed; at the moment this version is slow.

This is my second attempt at the problem:

with autocast(True):
    out = discrim(x=sample)

    with autocast(False):
        sample = scaler.scale(sample)
        out = scaler.scale(out)

        gradients = torch.autograd.grad(inputs=sample, outputs=out,
            grad_outputs=torch.ones(out.shape).to(sample.device),
            create_graph=True, retain_graph=True, only_inputs=True)[0]

        gradients = gradients / scaler.get_scale()

But I get this error from autograd: “One of the differentiated Tensors appears to not have been used in the graph”. I’m not really sure what the proper way to do this is. Any help is greatly appreciated!


I don’t believe your use case is different from what’s shown in the docs, aside from the more explicit kwargs and outputs being a Tensor with (presumably) more than one element. The error in your second attempt, by the way, comes from rebinding sample to scaler.scale(sample) after out was already computed: the rescaled sample is a new Tensor that never participated in out’s graph, which is exactly what autograd is complaining about. Perhaps, instead of

To implement a gradient penalty with gradient scaling, the loss passed to torch.autograd.grad() should be scaled.

the docs should say

To implement a gradient penalty with gradient scaling, the outputs Tensor(s) passed to torch.autograd.grad() should be scaled.

Try

with autocast():
    out = discrim(sample)

# scaling out here protects the float16 backward pass from underflow
gradients = torch.autograd.grad(outputs=scaler.scale(out), inputs=sample,
    grad_outputs=torch.ones(out.shape, device=sample.device),
    create_graph=True, retain_graph=True, only_inputs=True)[0]

# proceed as shown in the docs

# if gradients is a Tensor
gradients = gradients / scaler.get_scale()
# if gradients is a tuple
inv_scale = 1./scaler.get_scale()
gradients = [p * inv_scale for p in gradients]

with autocast():
    # compute penalty term from gradients
    # add to loss if needed

scaler.scale(<loss or penalty>).backward()

If that works, I’ll update the docs as described above.
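For completeness, here’s how those pieces could slot into a single discriminator step. Treat this as a minimal sketch, not a drop-in: gen, discrim, d_opt, real, noise, and lambda_gp are hypothetical names on my side, and the critic loss shown is just one common WGAN formulation.

import torch
from torch.cuda.amp import autocast, GradScaler

# Minimal sketch of one discriminator step; gen, d_opt, real, noise,
# and lambda_gp are placeholder names.
scaler = GradScaler()  # create once, outside the training loop

d_opt.zero_grad()

with autocast():
    sample = gen(noise)                            # fake batch from the generator
    out = discrim(sample)                          # N x 1 critic scores
    d_loss = out.mean() - discrim(real).mean()     # example WGAN critic loss

# scale out so the float16 backward pass doesn't underflow
gradients = torch.autograd.grad(outputs=scaler.scale(out), inputs=sample,
    grad_outputs=torch.ones(out.shape, device=sample.device),
    create_graph=True)[0]
# unscale manually; these gradients aren't owned by any optimizer
gradients = gradients / scaler.get_scale()

with autocast():
    # per-sample gradient norm over the non-batch dims
    grad_norm = gradients.flatten(start_dim=1).norm(2, dim=1)
    penalty = ((grad_norm - 1.) ** 2).mean()
    total_loss = d_loss + lambda_gp * penalty

scaler.scale(total_loss).backward()
scaler.step(d_opt)
scaler.update()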

FYI, in grad, retain_graph defaults to the value of create_graph and only_inputs defaults to True, so those kwargs aren’t needed. Also, I changed
torch.ones(out.shape).to(sample.device) to torch.ones(out.shape, device=sample.device), which synthesizes grad_outputs directly on the device, avoids an expensive CPU->GPU copy, and is good practice regardless of AMP.
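With those defaults, the grad call in my snippet above trims down to:

gradients = torch.autograd.grad(outputs=scaler.scale(out), inputs=sample,
    grad_outputs=torch.ones(out.shape, device=sample.device),
    create_graph=True)[0]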


Thank you for the reply!

I’ve implemented the suggestion you made and it all appears to be working. Between your comment and cross-referencing the docs, what’s happening makes much more sense now. Thanks also for the additional notes; I didn’t realize to(...) had that cost, and I’ve made that change as well.

I agree that if the docs referred to the outputs Tensor(s), this would make more sense for newcomers in the future.
