Location of scripted function's definition affects result

Hi, I’m relatively new to TorchScript, and I’m seeing an effect where the grad_fn is not propagated, depending on where the scripted function is defined. Here’s a simplified example:

import torch

@torch.jit.script
def func0(x):
    return x + x + x

class SomeModule(torch.nn.Module):

    def forward(self, inp):
        # inp: has grad_fn

        @torch.jit.script
        def func1(x):
            return x + x + x

        out0 = func0(inp)  # grad_fn not set
        out1 = func1(inp)  # grad_fn set
        return out0, out1

This only illustrates the general idea (I’m working with a much larger system), but with the same function definition, grad_fn gets set when the function is defined in the local scope, yet not in the global scope. Even odder, if func0 returns only x + x instead of x + x + x, grad_fn is set as expected.

Any ideas what could be causing these differences? Thanks! (I’m currently using PyTorch 1.11, but I’ve seen the same effect since at least 1.8.)
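
For completeness, here’s the single-op global variant that does behave as expected for me (a minimal sketch; func0_single is just an illustrative name):

import torch

@torch.jit.script
def func0_single(x):
    return x + x  # single op: grad_fn is set as expected, even in global scope

inp = torch.randn(1, requires_grad=True) + 1
out = func0_single(inp)
print(out.grad_fn)  # e.g. <AddBackward0 ...>, so autograd tracking is intact here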

I’m seeing grad_fns for both cases:

import torch

@torch.jit.script
def func0(x):
    return x + x + x

class SomeModule(torch.nn.Module):

    def forward(self, inp):
        # inp: has grad_fn

        @torch.jit.script
        def func1(x):
            return x + x + x

        out0 = func0(inp) # grad_fn not set
        out1 = func1(inp) # grad_fn set
        
        return out0, out1
    
model = SomeModule()
inp = torch.randn(1, requires_grad=True) + 1
print(inp.grad_fn)
# > <AddBackward0 object at 0x7f63c5c76940>
out0, out1 = model(inp)
print(out0.grad_fn)
# > <CppFunction at 0x7f63c5c76c10>
print(out1.grad_fn)
# > <AddBackward0 at 0x7f63c5c76dc0>

This is with 1.11.0.dev20211101. Could you update to the latest nightly binary and rerun the code?

I tried it in a slightly different manner and it works:

import torch
from torch import nn

def func0(x, y):
    return x + y

class SomeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.y = nn.Parameter(torch.randn([1]))

    def forward(self, inp):
        def func1(x):
            return x + self.y

        out0 = func0(inp, self.y)
        out1 = func1(inp)
        return out0, out1
        
model = SomeModule()
inp = torch.randn(1) + 1
print(inp.grad_fn)
out0, out1 = model(inp)
print(out0.grad_fn)
print(out1.grad_fn)
print(torch.__version__)
# > 1.8.1

Thanks for your responses. I probably wasn’t clear: this isn’t the actual program I’m running (this simplified version works for me, too); rather, I was using it to illustrate that there’s a difference depending on whether the scripted function is defined in the global scope or the local scope. Adding to this, it also doesn’t work if the function is defined inside def __init__() and assigned with self.func = func, and then called later in forward (roughly the pattern sketched below). The only case that works for me is when the scripted function is defined within forward.
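
Here’s a minimal sketch of that __init__ pattern (made-up names, not my actual code):

import torch

class SomeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

        @torch.jit.script
        def func(x):
            return x + x + x

        # keep a reference to the scripted function and call it later in forward
        self.func = func

    def forward(self, inp):
        return self.func(inp)  # grad_fn is not set here either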

One other detail I noticed: the global case works if the function only performs a single operation (e.g., x + x), but not with two or more operations (e.g., x + x + x), which leads me to believe this might be a bug in the fuser (see the sketch below for what I mean).
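
To make the fuser suspicion concrete, here’s roughly how one could check whether the behavior changes across calls (a minimal sketch; as far as I understand, the profiling executor may only apply optimizations like fusion after the first run or two, so I call the scripted function a few times and print grad_fn each time):

import torch

@torch.jit.script
def func0(x):
    return x + x + x  # the global-scope, multi-op case

inp = torch.randn(1, requires_grad=True) + 1
for i in range(3):
    # if a fuser pass is involved, grad_fn may change once optimization kicks in
    print(i, func0(inp).grad_fn)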

Also, I’m not sure, but this topic seems possibly related: Second forward call of torchscripted module breaks on cuda

Thanks again!