Hello, I’ve been trying to convert a model that uses a custom `autograd.Function` to Core ML.
```python
import torch
from torch import autograd

class Custom(autograd.Function):
    @staticmethod
    def forward(ctx, input1, input2, param1, param2, param3):
        ...

    @staticmethod
    def backward(ctx, grad_output):
        ...
```
When the model is traced, invocations of `Custom.apply` result in nodes of kind `prim::PythonOp`, with `input1` and `input2` as `node.inputs()`. Unlike other kinds of ops, the scalar parameters `param1`, `param2`, and `param3` do not seem to be registered as `prim::Constant` nodes in the graph (as far as I can tell).
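Here’s roughly how I’m observing this; the forward/backward bodies below are just placeholders to make the trace concrete:

```python
import torch
from torch import autograd

class Custom(autograd.Function):
    @staticmethod
    def forward(ctx, input1, input2, param1, param2, param3):
        ctx.param1, ctx.param2 = param1, param2
        return input1 * param1 + input2 * param2 + param3

    @staticmethod
    def backward(ctx, grad_output):
        # Gradients w.r.t. the two tensor inputs; None for the scalars.
        return grad_output * ctx.param1, grad_output * ctx.param2, None, None, None

class Model(torch.nn.Module):
    def forward(self, x, y):
        return Custom.apply(x, y, 2.0, 3.0, 4.0)

traced = torch.jit.trace(Model(), (torch.randn(3), torch.randn(3)))
for node in traced.graph.nodes():
    print(node.kind(), [v.debugName() for v in node.inputs()])
# The prim::PythonOp node lists only the two tensor inputs;
# 2.0, 3.0, 4.0 show up nowhere in the graph as prim::Constant.
```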
I was hoping to monkey-patch `InternalTorchIRNode` from coremltools and handle these parameters separately. Specifically, my plan was to construct `constant`-kinded `InternalTorchIRNode`s as if `torch.jit.trace` had produced the corresponding `prim::Constant` nodes, add them to the `InternalTorchIRGraph` instance myself, and wire everything else together.
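As a sketch of what I have in mind (the import path is from the coremltools version I’m on, and the constructor keywords are my best reading of `internal_graph.py`, so treat this as an approximation):

```python
from coremltools.converters.mil.frontend.torch.internal_graph import (
    InternalTorchIRNode,
)

def add_scalar_constant(graph, name, value):
    # Mimic what a traced prim::Constant becomes: a "constant"-kinded
    # node with no inputs, a single output, and the value stored on
    # node.attr["value"], which is where coremltools' constant op
    # appears to read it from.
    const_node = InternalTorchIRNode(
        kind="constant",
        inputs=[],
        outputs=[name],
        attr={"value": value},
    )
    # Prepend so the constant is defined before any node that uses it.
    graph.nodes.insert(0, const_node)
    return name
```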
This seems pretty doable: I’m readily able to access the values of these parameters via `node.scalar_args()`. I also wanted access to the names of the parameters (in order to label my Core ML nodes), and here’s where things got tricky: I can get a reference to the `Custom.apply` function via `node.pyobj()`, but it’s not possible to get a reference to `Custom` itself from there (such a reference would let me inspect the signature of `Custom.forward`, which is what I’m after). `node.pyname()` is not helpful either, since in general the name `'Custom'` alone is not enough to unambiguously discover the autograd function.
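If I could get at the class, recovering the names would be trivial, e.g.:

```python
import inspect

# Hypothetical: assumes a reference to the Custom class itself,
# which is exactly what I can't obtain from the traced node.
arg_names = list(inspect.signature(Custom.forward).parameters)[1:]  # drop ctx
# -> ['input1', 'input2', 'param1', 'param2', 'param3']; the trailing
#    three line up with the values from node.scalar_args().
```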
Looking around for a solution, one thing caught my eye: the `autogradFunction` method of `ConcretePythonOp` (in PyTorch’s C++ sources) looks like exactly what I need. It doesn’t seem to be exposed to the Python side, however, which brings me to my question: is there a way to access these parameter names, or could `autogradFunction` be exposed on `torch._C.Node`?
I don’t fully understand everything that’s involved (possibly what I’m asking for is not technically feasible), but I thought I’d ask here anyway. Thanks!