JIT-scripting an nn.Module with function attributes

Suppose I want to instantiate an nn.Module that stores a function as an attribute:

import torch
import torch.nn as nn

class Foo(nn.Module):
    def __init__(self, fn):
        super(Foo, self).__init__()
        self.fn = fn  # plain Python function stored as an attribute

    def forward(self, x):
        return self.fn(x)

This works fine in eager mode:

def square(x):
    return x**2

foonet = Foo(square)

But scripting the module fails:

foonet_script = torch.jit.script(foonet)

giving

RuntimeError: 
module has no attribute 'fn':
at <ipython-input-25-24af853ed724>:7:15
    def forward(self, x):
        return self.fn(x)
               ~~~~~~~ <--- HERE

It doesn't help if I first script the function:

square_script = torch.jit.script(square)

and then instantiate Foo with the resulting torch._C.Function. On the other hand, if I pass in another nn.Module, everything works. Setting up a whole new Module when all I need is a function call seems like overkill, though.
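For reference, the nn.Module workaround described above can be sketched as follows. This is a minimal sketch, not the poster's exact code: the wrapper class name `Square` is hypothetical. TorchScript compiles registered submodules recursively, so `self.fn` resolves to a scripted submodule instead of an unsupported function attribute.

```python
import torch
import torch.nn as nn


class Square(nn.Module):
    # Hypothetical wrapper: a tiny module whose only job is the function call.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x ** 2


class Foo(nn.Module):
    def __init__(self, fn: nn.Module):
        super().__init__()
        self.fn = fn  # registered as a submodule, so torch.jit.script can compile it

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fn(x)


foonet_script = torch.jit.script(Foo(Square()))
out = foonet_script(torch.tensor([1.0, 2.0, 3.0]))
print(out)  # tensor([1., 4., 9.])
```

The cost is one extra class per function, which is the overhead the question is trying to avoid.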

Update: It's a known bug; see this issue for function attributes and this post for class attributes.


This bug should be fixed in https://github.com/pytorch/pytorch/pull/28569. You can try out the fix by either building master from source or installing the nightly PyTorch pip package.
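One way to check whether an installed build includes the fix is to script the original example directly. This is a sketch assuming a PyTorch version that contains the change from that pull request; on older versions, `torch.jit.script` is expected to raise the "module has no attribute 'fn'" error quoted in the question.

```python
import torch
import torch.nn as nn


def square(x: torch.Tensor) -> torch.Tensor:
    return x ** 2


class Foo(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn  # plain function attribute, as in the original question

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fn(x)


# Assumes a build with the fix: the function attribute is compiled along with the module.
foonet_script = torch.jit.script(Foo(square))
out = foonet_script(torch.tensor([1.0, 2.0, 3.0]))
print(out)  # tensor([1., 4., 9.])
```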
