Hi,
Can someone explain whether this behavior is intended and, if so, why? And what would be the way to get the same behavior with scripting as without it? Here is a minimal example:
import torch
@torch.jit.script
def scripted_bar(x):
    """TorchScript counterpart of ``bar``.

    Computes ``y = x * 2`` inside a ``torch.no_grad()`` block and then
    ``z = 2 * x + y`` after the block has been exited. Prints the
    ``requires_grad`` flags of ``x``, ``y`` and ``z`` and returns ``z``.
    """
    # Only this multiplication is wrapped in no_grad.
    with torch.no_grad():
        y = x * 2
    # z is computed outside the no_grad block.
    z = 2 * x + y
    print("scripted_bar", x.requires_grad, y.requires_grad, z.requires_grad)
    return z
def bar(x):
    """Eager-mode reference for ``scripted_bar``.

    Computes ``y = x * 2`` inside a ``torch.no_grad()`` block and then
    ``z = 2 * x + y`` after the block has been exited. Prints the
    ``requires_grad`` flags of ``x``, ``y`` and ``z`` and returns ``z``.
    """
    # Only this multiplication is wrapped in no_grad, so y is not tracked.
    with torch.no_grad():
        y = x * 2
    # z is computed outside the no_grad block.
    z = 2 * x + y
    print("bar", x.requires_grad, y.requires_grad, z.requires_grad)
    return z
# Two random leaf tensors with gradient tracking enabled; only `a` is passed
# to the functions below, `b` is created but not used.
a, b = torch.rand(3, requires_grad=True), torch.rand(3, requires_grad=True)

# Run the scripted version first, then the eager one, and report whether
# each returned tensor tracks gradients.
scripted_out = scripted_bar(a)
print("scripted_bar(a)", scripted_bar, scripted_out.requires_grad)
eager_out = bar(a)
print("bar(a)", bar, eager_out.requires_grad)
Output:
scripted_bar True False False
scripted_bar(a) <torch.jit.ScriptFunction object at 0x7fa72315d450> False
bar True False True
bar(a) <function bar at 0x7fa797fae3a0> True
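With jit.script, the z variable has no gradient tracking (z.requires_grad is False), even though it is computed outside the torch.no_grad() block; in eager mode it is True, as expected.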