Suppose I want to instantiate an nn.Module that wraps an arbitrary function:
```python
import torch
import torch.nn as nn

class Foo(nn.Module):
    def __init__(self, fn):
        super(Foo, self).__init__()
        self.fn = fn  # a plain Python callable stored as an attribute

    def forward(self, x):
        return self.fn(x)
```
This works fine in eager mode:
```python
def square(x):
    return x**2

foonet = Foo(square)
```
But scripting the module fails:
```python
foonet_script = torch.jit.script(foonet)
```
with the error:
```
RuntimeError:
module has no attribute 'fn':
at <ipython-input-25-24af853ed724>:7:15
    def forward(self, x):
        return self.fn(x)
               ~~~~~~~ <--- HERE
```
It doesn't help if I first script the function

```python
square_script = torch.jit.script(square)
```

and then instantiate Foo with the resulting torch._C.Function. On the other hand, if I pass in another nn.Module, scripting succeeds. Setting up a whole new Module when all I want is a function call seems like overkill, though.
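For reference, here is a minimal sketch of the nn.Module workaround described above, assuming a reasonably recent PyTorch release. `Square` is a hypothetical wrapper name chosen for illustration. The sketch also shows a different technique, `torch.jit.trace`, as an alternative: tracing runs `forward` once on an example input and records the tensor operations, so a plain Python function attribute does not need to be compiled. The trade-off is that tracing only captures the path taken for that input, so it is only safe when there is no data-dependent control flow.

```python
import torch
import torch.nn as nn


class Foo(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)


def square(x):
    return x ** 2


class Square(nn.Module):
    """Hypothetical thin wrapper: a Module whose forward just calls square()."""

    def forward(self, x):
        # TorchScript compiles the free function square() recursively from here
        return square(x)


# Scripting works: fn is now a registered submodule, which is compiled recursively.
scripted = torch.jit.script(Foo(Square()))

# Alternative (tracing, not scripting): forward runs in eager Python and the
# resulting tensor ops are recorded, so the plain function attribute is fine.
x = torch.arange(4.0)
traced = torch.jit.trace(Foo(square), x)

print(scripted(x))  # tensor([0., 1., 4., 9.])
print(traced(x))    # tensor([0., 1., 4., 9.])
```

The wrapper costs only a few lines per function and keeps `Foo` fully scriptable, including any control flow in the wrapped function. Tracing avoids the wrapper but bakes in whatever branches were taken on the example input.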