For some reason, I am using gradient descent as part of the forward function of my model.
Hence, I wonder if it would be possible to enforce gradient computation in all cases, even when a torch.no_grad() context manager is active?
You can call torch.set_grad_enabled(True) to re-enable gradient tracking even inside a no_grad context.
For instance:
import torch

x = torch.ones(5, requires_grad=True)
with torch.no_grad():
    torch.set_grad_enabled(True)  # overrides the enclosing no_grad
    y = x * x
    y.sum().backward()
print(x.grad)  # works fine: tensor([2., 2., 2., 2., 2.])
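Equivalently, you can use the torch.enable_grad() context manager, which re-enables gradient tracking for its block and automatically restores the previous grad mode on exit:

```python
import torch

x = torch.ones(5, requires_grad=True)
with torch.no_grad():
    with torch.enable_grad():  # gradient tracking is back on inside this block
        y = x * x
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2., 2.])
```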
However, this does not work if you disabled gradients with the torch.inference_mode context: tensors created under inference mode are marked as inference tensors and can never participate in autograd, so re-enabling grad mode does not help.
x = torch.ones(5, requires_grad=True)
with torch.inference_mode():
    torch.set_grad_enabled(True)
    y = x * x
    y.sum().backward()  # fails with a RuntimeError
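For the original use case, a common pattern is to wrap only the inner optimization in torch.enable_grad(), so the model keeps working even when the caller runs it under no_grad. Below is a minimal, hypothetical sketch (the module name, step count, and learning rate are illustrative, not from the question): forward() refines an initial guess z by a few gradient-descent steps on a quadratic loss.

```python
import torch

class InnerGDModel(torch.nn.Module):
    """Hypothetical model whose forward() runs an inner gradient-descent loop."""

    def forward(self, target, steps=20, lr=0.1):
        # Force gradient tracking for the inner loop, even under an
        # outer torch.no_grad() (but not under torch.inference_mode()).
        with torch.enable_grad():
            z = torch.zeros_like(target).requires_grad_(True)
            for _ in range(steps):
                loss = ((z - target) ** 2).sum()
                (g,) = torch.autograd.grad(loss, z)
                # Plain gradient-descent update; detach so each step
                # starts a fresh graph.
                z = (z - lr * g).detach().requires_grad_(True)
        return z.detach()

model = InnerGDModel()
with torch.no_grad():  # the inner loop still tracks gradients
    out = model(torch.ones(3))
print(out)  # approaches the target tensor([1., 1., 1.])
```

Detaching z between steps keeps the graph small; if you instead need gradients of the inner loop's result with respect to model parameters, you would keep the graph and use create_graph=True in torch.autograd.grad.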