Hi, I’m relatively new to TorchScript, and I’m seeing an effect where grad_fn is not propagated, depending on where the scripted function is defined. Here’s a simplified example:
@torch.jit.script
def func0(x):
    return x + x + x

class SomeModule(torch.nn.Module):
    def forward(self, inp):
        # inp: has grad_fn
        @torch.jit.script
        def func1(x):
            return x + x + x
        out0 = func0(inp)  # grad_fn not set
        out1 = func1(inp)  # grad_fn set
This only represents the general idea (I’m working with a much larger system), but I’m seeing that with the same function definition, grad_fn gets set when the function is defined in the local scope, but not in the global scope. Even more odd, if func0 returns only x + x instead of x + x + x, then grad_fn is set as expected.
Any ideas what could be causing these differences? Thanks! (I’m currently using PyTorch 1.11, but I’ve seen the same effect since at least 1.8.)
import torch

@torch.jit.script
def func0(x):
    return x + x + x

class SomeModule(torch.nn.Module):
    def forward(self, inp):
        # inp: has grad_fn
        @torch.jit.script
        def func1(x):
            return x + x + x
        out0 = func0(inp)  # grad_fn not set
        out1 = func1(inp)  # grad_fn set
        return out0, out1

model = SomeModule()
# Adding 1 makes inp a non-leaf tensor, so it has a grad_fn of its own.
inp = torch.randn(1, requires_grad=True) + 1
print(inp.grad_fn)
# > <AddBackward0 object at 0x7f63c5c76940>
out0, out1 = model(inp)
print(out0.grad_fn)
# > <CppFunction at 0x7f63c5c76c10>
print(out1.grad_fn)
# > <AddBackward0 at 0x7f63c5c76dc0>
I ran the code above in 1.11.0.dev20211101, and both outputs have a grad_fn. Could you update to the latest nightly binary and rerun the code?
Thanks for your responses. I probably wasn’t clear: this isn’t the actual program I’m running (the simplified version works for me as well). I was using it to illustrate that the behavior differs depending on whether the scripted function is defined in the global scope or in a local scope. Adding to this, it also fails if the function is defined inside __init__(), stored with self.func = func, and then called later in forward. The only case that works is when the scripted function is defined within forward itself.
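The three placements described above can be put side by side in one small probe. This is only a sketch under assumptions: the names PlacementProbe, func_global, func_init, func_local and grad_fn_report are made up for illustration, and the functions are the toy x + x + x from the original post, not the real system. It calls the module several times because, as far as I understand, the JIT may switch to an optimized graph only after some warm-up calls, so a single call might not hit the path that misbehaves.

```python
import torch


@torch.jit.script
def func_global(x):
    # Placement 1: scripted at module (global) scope.
    return x + x + x


class PlacementProbe(torch.nn.Module):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()

        @torch.jit.script
        def func_init(x):
            # Placement 2: scripted inside __init__ and stored on the module.
            return x + x + x

        self.func = func_init

    def forward(self, inp):
        @torch.jit.script
        def func_local(x):
            # Placement 3: scripted inside forward (re-scripted on every call).
            return x + x + x

        return {
            "global": func_global(inp),
            "init": self.func(inp),
            "local": func_local(inp),
        }


def grad_fn_report(model, inp, n_calls=5):
    """Record, for each call, whether each output carries a grad_fn."""
    report = []
    for _ in range(n_calls):
        outs = model(inp)
        report.append({name: out.grad_fn is not None for name, out in outs.items()})
    return report


model = PlacementProbe()
# Adding 1 makes inp a non-leaf tensor, so it has a grad_fn like in the thread.
inp = torch.randn(4, requires_grad=True) + 1
report = grad_fn_report(model, inp)
for call_index, row in enumerate(report):
    print(call_index, row)
```

If the toy version reports True everywhere (as it did in the reply above), swapping the real function bodies into the three placements one at a time may help narrow down which placement and which call first loses grad_fn.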
One other detail I noticed is that the global case works if the function performs only a single operation (e.g., x + x) rather than two or more operations (e.g., x + x + x), which leads me to believe this might be a bug in the fuser.
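One way to test the fuser suspicion is to run the same scripted functions with the JIT's graph optimizations switched off and compare. The sketch below assumes that torch.jit.optimized_execution(False) is enough to bypass the optimization passes involved; that is an assumption, not something confirmed in this thread. The helper grad_fn_per_call and the names one_op and two_ops are hypothetical.

```python
import torch


@torch.jit.script
def one_op(x):
    # Single operation: reported to keep grad_fn in the global case.
    return x + x


@torch.jit.script
def two_ops(x):
    # Two chained operations: reported to lose grad_fn in the global case.
    return x + x + x


def grad_fn_per_call(fn, inp, n_calls=5, optimize=True):
    """Call fn repeatedly and record whether each output has a grad_fn.

    Repeated calls matter because the executor may only use an
    optimized graph after some warm-up runs.
    """
    flags = []
    with torch.jit.optimized_execution(optimize):
        for _ in range(n_calls):
            flags.append(fn(inp).grad_fn is not None)
    return flags


inp = torch.randn(4, requires_grad=True) + 1
results = {
    (fn_name, optimize): grad_fn_per_call(fn, inp, optimize=optimize)
    for fn_name, fn in (("one_op", one_op), ("two_ops", two_ops))
    for optimize in (True, False)
}
for key, flags in results.items():
    print(key, flags)
```

If grad_fn goes missing only with optimize=True, and only for the multi-operation function, that would be consistent with an optimization or fusion pass being responsible; if it goes missing in both modes, the cause is more likely elsewhere.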