torch.no_grad works for operations only

In the following code, I see that requires_grad is False for a**2 but True for a, even inside torch.no_grad(). Why is a.requires_grad still True?

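import torch

a = torch.ones(2, requires_grad=True)  # assumed definition; the original post does not show how a was created

print(a.requires_grad)
print((a**2).requires_grad)
with torch.no_grad():
    print(a.requires_grad)
    print((a**2).requires_grad)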

Output:

True
True
True
False

Hi,

torch.no_grad() only changes the behavior of PyTorch ops by forcing them not to track gradients.
So if you do not run any op, it has no effect: reading a.requires_grad is just an attribute lookup on an existing tensor, not an op.
It is unclear whether we want to change that in the future (to make a.requires_grad print False here), but this is the current behavior.
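
To make the distinction concrete, here is a minimal sketch. The tensor `a` and the variable names are my own choices for illustration, not taken from the original post. It checks the flag on the existing tensor, the flag on a result computed inside the block, and the grad mode itself via torch.is_grad_enabled():

```python
import torch

# Hypothetical leaf tensor standing in for `a` from the question.
a = torch.ones(3, requires_grad=True)

with torch.no_grad():
    # Attribute of an already existing tensor: no op runs, so nothing changes.
    flag_inside = a.requires_grad
    # Result of an op executed inside the block: not tracked.
    result_inside = (a ** 2).requires_grad
    # The global grad mode that no_grad() switches off.
    grad_mode_inside = torch.is_grad_enabled()

# The same op outside the block is tracked again.
result_outside = (a ** 2).requires_grad

print(flag_inside, result_inside, grad_mode_inside, result_outside)
# True False False True
```

If you want to know whether results computed at a given point will be tracked, torch.is_grad_enabled() is probably the more direct thing to check than a.requires_grad.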

Is there any reason to keep it the same way?

The main reason is that there is no consensus that changing it would be the right thing to do.
Having t.requires_grad change based on a global flag that could be controlled from other code could be quite surprising to the user as well.
Also, many people are now used to the current behavior, so we would want a solid argument before changing it.
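
As a rough illustration of the "global flag controlled from other code" point, here is a sketch under the current behavior. The function third_party_helper is hypothetical and only stands in for code you do not control; torch.set_grad_enabled is the real call it uses to flip the flag:

```python
import torch

def third_party_helper():
    # Hypothetical library code that turns grad mode off globally
    # and does not restore it.
    torch.set_grad_enabled(False)

a = torch.ones(3, requires_grad=True)

before = a.requires_grad           # True
third_party_helper()
after = a.requires_grad            # still True: the tensor's own flag is untouched
op_after = (a * 2).requires_grad   # False: only new ops see the global flag

torch.set_grad_enabled(True)       # restore the default grad mode

print(before, after, op_after)
# True True False
```

With the current behavior, a.requires_grad always reflects what you set on the tensor itself. If it followed the global flag instead, `after` would silently become False here because of code elsewhere.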