Computing gradients with Apex

I want to compute gradients in the standard manner:

grads = torch.autograd.grad(loss, model.parameters())
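For reference, here is a minimal CPU-only sketch of that call. The toy `Linear` model and MSE loss are placeholders, not the asker's code. It shows that `torch.autograd.grad` returns one gradient per parameter and, unlike `loss.backward()`, does not write anything into `param.grad`:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(3, 1)  # toy model standing in for the real one
x, y = torch.randn(5, 3), torch.randn(5, 1)
loss = torch.nn.functional.mse_loss(model(x), y)

# One gradient tensor per parameter, in the same order as model.parameters()
grads = torch.autograd.grad(loss, model.parameters())

for param, grad in zip(model.parameters(), grads):
    assert grad.shape == param.shape
    assert param.grad is None  # .grad stays unpopulated, unlike with loss.backward()
```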

I want to use AMP. To get the scaled loss, the Apex API suggests:

with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

But I don’t want to call loss.backward() directly; I compute the gradients as shown above.
I tried this:

with amp.scale_loss(loss, optimizer) as scaled_loss:
    grad = torch.autograd.grad(scaled_loss, model.parameters())

This makes the loss NaN. How can I compute gradients with Apex?

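We recommend using the native automatic mixed-precision training (torch.cuda.amp), available in a master build or the nightly binaries.
For your use case, you might have forgotten to unscale the gradients, or you might have manually updated the parameters with invalid (inf/NaN) gradients. Gradients computed from a scaled loss are multiplied by the same scale factor, so they have to be divided by it before they are used.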
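A minimal sketch of what unscaling means when the gradients come from `torch.autograd.grad`. It is CPU-only and uses a fixed, illustrative scale in place of a real loss scaler. The toy model and data are placeholders:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)  # toy model standing in for the real one
x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)

scale = 2.0 ** 16  # illustrative loss scale; GradScaler's default init_scale is also 2**16

# Differentiating the scaled loss yields gradients multiplied by the same factor
scaled_grads = torch.autograd.grad(loss * scale, model.parameters())

# Any inf/NaN here means the update must be skipped (GradScaler.step does this automatically)
found_inf = not all(torch.isfinite(g).all() for g in scaled_grads)

# Unscale by plain division before using the gradients for anything else
grads = [g / scale for g in scaled_grads]
```

With native AMP, `scaler.get_scale()` would take the place of the fixed `scale`. In float16 the largest finite value is 65504, so large scaled gradients can overflow to inf, which is why the finite check matters before any manual update.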

Have a look at the gradient penalty example in the AMP examples for a use case of torch.autograd.grad.
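A condensed sketch in the spirit of that gradient penalty pattern: one training step with native AMP where the gradients come from `torch.autograd.grad` on the scaled loss. The toy model, data and learning rate are placeholders. Scaling is only enabled when CUDA is available, so on CPU the scaler is a no-op and the snippet still runs:

```python
import torch

use_amp = torch.cuda.is_available()  # loss scaling is only active on CUDA; on CPU the scaler is a no-op
device = "cuda" if use_amp else "cpu"

torch.manual_seed(0)
model = torch.nn.Linear(4, 1).to(device)  # toy model standing in for the real one
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # newer releases also offer torch.amp.GradScaler("cuda")
x = torch.randn(8, 4, device=device)
y = torch.randn(8, 1, device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device, enabled=use_amp):
    loss = torch.nn.functional.mse_loss(model(x), y)

# Gradients of the *scaled* loss; create_graph=True so the penalty term can be backpropagated
scaled_grad_params = torch.autograd.grad(scaler.scale(loss), model.parameters(), create_graph=True)

# These gradients are not owned by the optimizer, so unscale them by plain multiplication
inv_scale = 1.0 / scaler.get_scale()
grad_params = [g * inv_scale for g in scaled_grad_params]

with torch.autocast(device_type=device, enabled=use_amp):
    grad_norm = torch.sqrt(sum(g.pow(2).sum() for g in grad_params))
    loss = loss + grad_norm  # add the gradient penalty to the loss

# The regular scaled backward/step/update path handles inf checks and scale adjustment
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```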

Thanks for your reply. I did what you said, but before I could test your suggestion, I ran into another issue.
Without fp16, I’m able to train with batch size 128.
After amp.initialize() is called, I get an out-of-memory (OOM) error with the same batch size, which is unexpected since memory usage should decrease. A similar issue has been raised here.

The linked issue initialized models in a loop, which is not supported in apex/amp.
Could you post a reproducible code snippet using native amp, so that we could have a look?

Regarding the first issue, I modified my code as you suggested:

if self.args.fp16:
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_grad_params = torch.autograd.grad(self.scaler.scale(scaled_loss), params, allow_unused=True)

Before I can multiply the entries of scaled_grad_params by inv_scale, an error is thrown because scaled_grad_params is None.
I assume that amp.scale_loss is not supposed to be used like this. I searched the documentation, and this method is used together with loss.backward(), which I’m not calling. Can you please suggest what I am missing?
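In the snippet above the loss may be scaled twice. amp.scale_loss (Apex) already yields a scaled loss, and self.scaler.scale (presumably a native GradScaler) scales it again, so it seems safer to use only one of the two mechanisms. Independently of AMP, allow_unused=True makes torch.autograd.grad return None for every parameter that does not contribute to the loss, and those entries need to be skipped when unscaling. The following is a CPU-only sketch with a fixed scale standing in for scaler.get_scale(). The extra unused layer is hypothetical:

```python
import torch

torch.manual_seed(0)
used = torch.nn.Linear(4, 1)
unused = torch.nn.Linear(4, 1)  # hypothetical layer that does not contribute to the loss
params = list(used.parameters()) + list(unused.parameters())

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(used(x), y)

scale = 2.0 ** 16  # stands in for scaler.get_scale(); the loss is scaled exactly once
scaled_grad_params = torch.autograd.grad(loss * scale, params, allow_unused=True)

# allow_unused=True yields None for parameters outside the graph; keep None instead of multiplying it
inv_scale = 1.0 / scale
grad_params = [None if g is None else g * inv_scale for g in scaled_grad_params]

print([g is None for g in grad_params])  # [False, False, True, True]
```

If every entry comes back as None, none of the params were part of the graph that produced the loss, which is worth checking before looking at AMP.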

Is your code working without AMP, i.e. are you getting any (non-None) gradients from autograd.grad?

Yes, my code works fine without AMP.
I compute the gradients with:

grads = torch.autograd.grad(loss, model.parameters())

Then I update my model’s parameters:

for param, grad in zip(model.parameters(), grads):
    param.data = ...  # (new values computed from grad after some operation)
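A sketch of such a manual update when the gradients come from a scaled loss. It assumes plain SGD (with a hypothetical lr) as the placeholder for "some operation" and a fixed scale in place of a GradScaler. Because scaler.step() is not involved, the loop itself has to skip non-finite gradients:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)  # toy model standing in for the real one
x, y = torch.randn(8, 4), torch.randn(8, 1)
lr = 0.01  # hypothetical learning rate
scale = 2.0 ** 16  # hypothetical fixed loss scale

loss = torch.nn.functional.mse_loss(model(x), y)
scaled_grads = torch.autograd.grad(loss * scale, model.parameters())
grads = [g / scale for g in scaled_grads]  # unscale before using the gradients

# scaler.step() would skip the update on inf/NaN; with manual updates this check is up to you
if all(torch.isfinite(g).all() for g in grads):
    with torch.no_grad():  # parameter updates must not be recorded by autograd
        for param, grad in zip(model.parameters(), grads):
            param -= lr * grad  # plain SGD as a stand-in for "some operation"
```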

I don’t know why the output is apparently empty when you are using AMP, so we would need to debug it.
Could you post a minimal, executable code snippet so that we could have a look?