How to trace a non-forward function

Hi, I have a model class that inherits from nn.Module, but its main use case is a member function (not the forward function). I'd like to trace/script this model, but I have no idea how to do it. For example, the class looks something like this:

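import torch.nn as nn

class A(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.linear_layer = ...  # xxxx: some learnable layer
    def solve(self):
        pass

a = A()
a.solve()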

I want to JIT-compile a.solve so that it uses A's learnable parameters.

I think you would have to call self.solve() in the forward method to properly script the model.
In eager mode you could still use a.solve(), but if I’m not mistaken, the JIT expects the “standard” implementation using forward.
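Here is a minimal sketch of that suggestion. It assumes `solve` takes a single tensor. The `nn.Linear(4, 2)` layer, the input shape and the `relu` body are hypothetical stand-ins for the parts the question elides. `forward` simply delegates to `solve`, so both `torch.jit.script` and `torch.jit.trace` reach it, while `a.solve(...)` keeps working in eager mode.

```python
import torch
import torch.nn as nn


class A(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical layer and sizes standing in for the original "xxxx".
        self.linear_layer = nn.Linear(4, 2)

    def solve(self, x: torch.Tensor) -> torch.Tensor:
        # Hypothetical body: any computation using the learnable parameters.
        return torch.relu(self.linear_layer(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Route the JIT entry point through solve().
        return self.solve(x)


a = A()
example = torch.randn(3, 4)

# Eager mode: solve() can still be called directly.
eager_out = a.solve(example)

# Scripting compiles forward() and the methods it calls, including solve().
scripted = torch.jit.script(a)

# Tracing runs forward() once on the example input and records the ops.
traced = torch.jit.trace(a, example)

print(torch.allclose(scripted(example), eager_out))
print(torch.allclose(traced(example), eager_out))
```

Keep in mind that `torch.jit.trace` only records the operations executed for the example input. If `solve` contains data-dependent control flow, `torch.jit.script` is the safer choice. Recent PyTorch versions also provide the `@torch.jit.export` decorator, which compiles a method even when `forward` does not call it. That may be worth checking in the TorchScript docs if routing everything through `forward` is inconvenient.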