Accessing parameter names of JIT-traced autograd.Functions

Hello, I’ve been trying to convert a model with a custom autograd.Function to coreml.

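from torch import autograd

class Custom(autograd.Function):
    @staticmethod
    def forward(ctx, input1, input2, param1, param2, param3):
        ...  # custom computation (omitted)

    @staticmethod
    def backward(ctx, grad_output):
        ...  # gradients w.r.t. each forward input (omitted)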
When traced, calls to Custom.apply produce nodes of kind prim::PythonOp, with input1 and input2 as node.inputs(). Unlike other kinds of ops, the scalar parameters param1, param2 and param3 do not appear to be registered as prim::Constant nodes in the graph (as far as I can see).

I was hoping to monkey-patch InternalTorchIRNode from coremltools and handle these parameters separately. Specifically, my idea was to construct constant-kinded InternalTorchIRNodes as if torch.jit.trace had produced the corresponding prim::Constant nodes, add them to the InternalTorchIRGraph instance myself, and wire everything else together.

This seems pretty doable: I can readily access the values of these parameters via node.scalar_args(). I also wanted the names of the parameters (so that I can label my coreml nodes), and this is where things got tricky. I can get a reference to the Custom.apply function via node.pyobj(), but I don't see a way to get from there to Custom itself (such a reference would let me inspect the signature of Custom.forward, which is what I'm after). node.pyname() doesn't help either, since in general the autograd function can't be discovered unambiguously from the name 'Custom' alone.
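Once the parameter names are available, pairing them with node.scalar_args() is straightforward. A minimal sketch, assuming the names come from forward's signature and that the traced node also exposes its calling convention via node.cconv(), a string with one character per apply() argument ('d' for a tensor, 'c' for a scalar). I believe that binding exists on torch._C.Node, but treat it as an assumption and check it on your PyTorch version. The helper label_python_op_args is hypothetical and pure Python, so it can be tried without tracing anything:

```python
# Hypothetical helper (not a PyTorch or coremltools API): pair the scalar
# arguments of a traced prim::PythonOp with forward()'s parameter names.
from inspect import signature
from typing import Any, Callable, Dict, List, Sequence, Tuple


def label_python_op_args(
    forward: Callable[..., Any],
    cconv: str,
    scalar_args: Sequence[Any],
) -> Tuple[List[str], Dict[str, Any]]:
    """Return (tensor_input_names, {scalar_param_name: value}).

    Assumes `cconv` has one character per argument passed to apply():
    'd' for a dynamic (tensor) input, 'c' for a constant (scalar) input,
    and that `scalar_args` lists the 'c' arguments in call order.
    """
    # Drop the leading `ctx` parameter of forward(ctx, ...).
    names = list(signature(forward).parameters)[1:]
    if len(names) != len(cconv):
        raise ValueError("cconv length does not match forward() arity")
    if cconv.count("c") != len(scalar_args):
        raise ValueError("scalar_args count does not match cconv")

    tensor_names: List[str] = []
    scalars: Dict[str, Any] = {}
    scalar_iter = iter(scalar_args)
    for name, kind in zip(names, cconv):
        if kind == "d":
            tensor_names.append(name)
        elif kind == "c":
            scalars[name] = next(scalar_iter)
        else:
            raise ValueError(f"unexpected cconv character {kind!r}")
    return tensor_names, scalars


# Stand-in with the same signature as Custom.forward.
def forward(ctx, input1, input2, param1, param2, param3):
    return input1


tensor_names, scalars = label_python_op_args(forward, "ddccc", [0.5, 3, True])
# tensor_names == ['input1', 'input2']
# scalars == {'param1': 0.5, 'param2': 3, 'param3': True}
```

On a traced graph, the call would be label_python_op_args(forward, node.cconv(), node.scalar_args()), with forward being Custom.forward however you get hold of it. Functions whose forward takes *args would need different handling, since the arity check assumes one named parameter per argument.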

While looking around for a solution, I noticed the autogradFunction method of ConcretePythonOp, which looks like exactly what I'm after. However, it doesn't seem to be exposed on the Python side, which brings me to my question: is there a way to access these parameter names, or could autogradFunction be exposed on torch._C.Node?

I don't fully understand everything that's involved, and what I'm asking for may not even be technically possible, but I thought I'd ask here anyway. Thanks!

Silly me: all I needed to do was read further into the implementation. 🙂
(TIL: the __self__ attribute)

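from inspect import signature
# node: a prim::PythonOp node; node.pyobj() is the bound Custom.apply, whose __self__ is Custom
params = list(signature(node.pyobj().__self__.forward).parameters)
# e.g. ['ctx', 'input1', 'input2', 'param1', 'param2', 'param3']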

There it is. Hope it’s useful to someone.
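Why __self__ works: in PyTorch, Function.apply is a classmethod, so the bound method returned by node.pyobj() carries the class it was looked up on in its __self__ attribute. The sketch below mimics that with a hypothetical stand-in base class (FakeFunction, so torch isn't needed) and, as a defensive assumption on my part, checks that the recovered class's own apply is the same callable we started from before trusting it:

```python
from inspect import signature
from typing import Any, List, Optional


def autograd_function_of(pyobj: Any) -> Optional[type]:
    """Return the class a bound `apply` belongs to, or None."""
    # For a classmethod looked up on a class, __self__ is that class.
    owner = getattr(pyobj, "__self__", None)
    if not isinstance(owner, type):
        return None
    # Sanity check: the owner's own `apply` must be the callable we started from.
    if getattr(owner, "apply", None) != pyobj:
        return None
    return owner


def forward_param_names(pyobj: Any) -> List[str]:
    fn_cls = autograd_function_of(pyobj)
    if fn_cls is None:
        raise TypeError("not a bound autograd.Function.apply")
    return list(signature(fn_cls.forward).parameters)


# Hypothetical stand-in for torch.autograd.Function so this runs without torch.
class FakeFunction:
    @classmethod
    def apply(cls, *args):
        return cls.forward(None, *args)


class Custom(FakeFunction):
    @staticmethod
    def forward(ctx, input1, input2, param1, param2, param3):
        return input1


names = forward_param_names(Custom.apply)
# names == ['ctx', 'input1', 'input2', 'param1', 'param2', 'param3']
```

With a real trace you would call forward_param_names(node.pyobj()) on a prim::PythonOp node; the first name is ctx, which you would usually drop before labeling the scalar arguments.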