Force gradient computation even in inference mode


For some reason, I am using gradient descent as part of the forward function of my model.

Hence, I wonder whether it is possible to force gradient computation in all cases, even when a no_grad context manager is active?

Thanks a lot!

Hi Antoine,

You can use torch.set_grad_enabled(True) to activate gradient tracking even under a no_grad context.

For instance:

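```python
import torch

x = torch.ones(5, requires_grad=True)
with torch.no_grad():
    # set_grad_enabled(True) re-enables gradient tracking inside the no_grad block
    with torch.set_grad_enabled(True):
        y = x * x

y.sum().backward()
print(x.grad)  # works fine: tensor([2., 2., 2., 2., 2.])
```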
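A minimal sketch of the use case from the question, assuming a model whose forward() runs a few plain gradient-descent steps on a toy quadratic loss (the class InnerGDModel, the target value 3.0 and the step settings are all hypothetical). Wrapping the inner loop in torch.enable_grad(), which behaves like torch.set_grad_enabled(True), keeps it working when the caller uses torch.no_grad():

```python
import torch
from torch import nn


class InnerGDModel(nn.Module):
    """Hypothetical model whose forward() runs an inner gradient-descent loop."""

    def __init__(self, steps: int = 100, lr: float = 0.1):
        super().__init__()
        # toy target the inner loop converges to (assumption for illustration)
        self.register_buffer("target", torch.full((5,), 3.0))
        self.steps = steps
        self.lr = lr

    def forward(self, z0: torch.Tensor) -> torch.Tensor:
        # enable_grad() turns gradient tracking back on even if the caller
        # wrapped this call in torch.no_grad()
        with torch.enable_grad():
            z = z0.clone().detach().requires_grad_(True)
            for _ in range(self.steps):
                loss = ((z - self.target) ** 2).sum()
                # differentiate the inner loss w.r.t. z only
                (g,) = torch.autograd.grad(loss, z)
                # one plain gradient-descent step, then start a fresh graph
                z = (z - self.lr * g).detach().requires_grad_(True)
        return z.detach()


model = InnerGDModel()
with torch.no_grad():
    out = model(torch.zeros(5))
print(out)  # every entry is approximately 3.0
```

torch.autograd.grad is used instead of backward() so that only the inner variable is differentiated and no .grad fields on the model are touched.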

However, this does not work if you disabled gradients with the torch.inference_mode context:

```python
x = torch.ones(5, requires_grad=True)
with torch.inference_mode():
    y = x * x

# y was created without a grad_fn, so backward() raises a RuntimeError
y.sum().backward()  # fails
```
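If you control the code that runs the inner loop, one possible workaround (an assumption on my part, not something confirmed above) is to switch inference mode off locally with torch.inference_mode(False) and to clone the input, since tensors created under inference mode cannot take part in computations recorded by autograd. A sketch, with the hypothetical helper inner_gd and the same kind of toy quadratic loss:

```python
import torch


def inner_gd(z0: torch.Tensor, steps: int = 100, lr: float = 0.1) -> torch.Tensor:
    """Hypothetical inner gradient-descent loop that minimizes ((z - 3) ** 2).sum()."""
    # inference_mode(False) locally switches inference mode off;
    # enable_grad() also covers callers that only used torch.no_grad()
    with torch.inference_mode(False), torch.enable_grad():
        # clone() turns a possible inference tensor into a normal tensor
        # that autograd is allowed to record operations on
        z = z0.clone().detach().requires_grad_(True)
        for _ in range(steps):
            loss = ((z - 3.0) ** 2).sum()
            (g,) = torch.autograd.grad(loss, z)
            z = (z - lr * g).detach().requires_grad_(True)
        return z.detach()


with torch.inference_mode():
    x = torch.zeros(5)  # created under inference mode: an inference tensor
    out = inner_gd(x)
print(out)  # every entry is approximately 3.0
```

This only helps when the inner loop is under your control: calling backward() on a tensor that was produced inside inference mode, as in the failing example, still has no graph to backpropagate through.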

Hi Victor,
OK, great, thanks a lot for the answer.

Just another question, then:
is there a way to detect inference mode so that I can issue a warning?

Yes, you can detect it with torch.is_inference_mode_enabled().
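For example, a hypothetical helper that warns when the forward pass is called under inference mode could look like this (the function name and the message text are illustrative):

```python
import warnings

import torch


def warn_if_inference_mode() -> bool:
    """Hypothetical helper: warn under torch.inference_mode() and report the mode."""
    enabled = torch.is_inference_mode_enabled()
    if enabled:
        warnings.warn(
            "forward() runs an inner gradient-descent loop, which cannot record "
            "gradients under torch.inference_mode(); use torch.no_grad() instead.",
            RuntimeWarning,
        )
    return enabled


with torch.inference_mode():
    warn_if_inference_mode()  # emits a RuntimeWarning
with torch.no_grad():
    warn_if_inference_mode()  # no warning: no_grad is not inference mode
```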
