torch.jit.script and torch.no_grad(): different behavior with and without JIT?

Hi,
can someone explain whether this behavior is intended and, if so, why? And what would be the solution to get the same behavior with scripting as without?

import torch


@torch.jit.script
def scripted_bar(x):
    with torch.no_grad():
        y = x * 2  # computed without gradient tracking

    z = 2 * x + y  # created outside no_grad, so it should require grad
    print("scripted_bar", x.requires_grad, y.requires_grad, z.requires_grad)

    return z


def bar(x):
    with torch.no_grad():
        y = x * 2  # computed without gradient tracking

    z = 2 * x + y  # created outside no_grad, so it should require grad
    print("bar", x.requires_grad, y.requires_grad, z.requires_grad)

    return z


a = torch.rand(3, requires_grad=True)
b = torch.rand(3, requires_grad=True)


print("scripted_bar(a)", scripted_bar, scripted_bar(a).requires_grad)
print("bar(a)", bar, bar(a).requires_grad)

Output:

scripted_bar True False False
scripted_bar(a) <torch.jit.ScriptFunction object at 0x7fa72315d450> False
bar True False True
bar(a) <function bar at 0x7fa797fae3a0> True

Under jit.script, the z variable has no gradient tracking, even though it is created outside the no_grad() block.
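As a possible workaround (a sketch, not an explanation of the root cause): detaching the intermediate tensor explicitly instead of entering a `no_grad()` context seems to sidestep the problem, since `z = 2 * x + y` then tracks gradients through `x` regardless of `y`. The function name below is made up for illustration:

```python
import torch


@torch.jit.script
def scripted_bar_detached(x):
    # detach() removes y from the autograd graph without entering a
    # no_grad() context, so subsequent ops still track gradients
    y = (x * 2).detach()
    z = 2 * x + y
    return z


a = torch.rand(3, requires_grad=True)
print(scripted_bar_detached(a).requires_grad)  # True
```

This should give the same result with and without scripting, because no grad-mode context manager is involved at all.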

This behavior indeed seems unexpected. Could you create a GitHub issue with this minimal code snippet, please?

Thank you for answering :slight_smile: I will create an issue.