I’m converting my distributed code to work with PyTorch’s AMP, but I’m confused about how I should compute the gradient penalty for my WGAN-GP loss. I’ve referred to the documentation on this, but I believe my use case is slightly different.

Here’s my code for computing the gradient penalty:

```
out = discrim(sample)
gradients = torch.autograd.grad(inputs=sample, outputs=out,
                                grad_outputs=torch.ones(out.shape).to(sample.device),
                                create_graph=True, retain_graph=True, only_inputs=True)[0]
```

You may assume that `sample` is a tensor yielded by my generator and `out` is an `N x 1` tensor. I’ll start off by saying that this code *works* without AMP, but I’m confused about how I’m supposed to use AMP with it. This is what I’m currently trying:

```
with autocast(False):
    sample = scaler.scale(sample)
    out = discrim(sample)
    gradients = torch.autograd.grad(inputs=sample, outputs=out,
                                    grad_outputs=torch.ones(out.shape).to(sample.device),
                                    create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients / scaler.get_scale()
```

Note that `sample` was computed under `autocast(True)`, but I’m not sure if this matters. This seems to work, but I think I want to compute `out` under the mixed-precision context to be faster; it seems to be slow at the moment.
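For comparison, here’s a rough sketch of the gradient-penalty pattern from the AMP examples in the docs, adapted to my variable names (the `discrim`, `sample`, and `scaler` below are small stand-ins, with scaling disabled so it runs on CPU): the forward pass goes under `autocast`, only the *output* is scaled, and the gradients are unscaled by dividing by the current scale factor.

```
import torch
from torch.cuda.amp import autocast, GradScaler

discrim = torch.nn.Linear(4, 1)        # stand-in discriminator
sample = torch.randn(8, 4, requires_grad=True)
scaler = GradScaler(enabled=False)     # enabled=True on CUDA

with autocast(enabled=False):          # enabled=True on CUDA
    out = discrim(sample)              # N x 1 output

# Scale the output, not the input; the resulting gradients come back
# scaled and are unscaled with the current scale factor.
scaled_out = scaler.scale(out)
gradients = torch.autograd.grad(inputs=sample, outputs=scaled_out,
                                grad_outputs=torch.ones_like(scaled_out),
                                create_graph=True, retain_graph=True)[0]
gradients = gradients / scaler.get_scale()
penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
```

The key difference from my attempts is that `sample` is never rebound to a scaled copy, so the tensor passed to `inputs=` is the same one that was actually used in the forward graph.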

This is my second attempt at the problem:

```
with autocast(True):
    out = discrim(x=sample)
with autocast(False):
    sample = scaler.scale(sample)
    out = scaler.scale(out)
    gradients = torch.autograd.grad(inputs=sample, outputs=out,
                                    grad_outputs=torch.ones(out.shape).to(sample.device),
                                    create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients / scaler.get_scale()
```

But autograd raises this error: `One of the differentiated Tensors appears to not have been used in the graph`. I’m not really sure what the proper way of doing this is. Any help is greatly appreciated!