ONNX custom operator runtime error

I have a simple custom operator that inherits from torch.autograd.Function.

import torch
import torch.onnx

from torch import nn
from torch.autograd import Function

class MyReLUFunction(Function):
    @staticmethod
    def symbolic(g, input):
        return g.op('MyReLU', input)

    @staticmethod
    def forward(ctx, input):
        ctx.input = ctx
        return input.clamp(0)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        grad_input.masked_fill_(ctx.input < 0, 0)
        return grad_input

class MyReLU(nn.Module):
    def forward(self, input):
        return MyReLUFunction.apply(input)

model = nn.Sequential(
        nn.Conv2d(1, 1, 3),
        MyReLU(),
        )
dummy_input = torch.randn(10, 1, 3, 3)
torch.onnx.export(model, dummy_input, "model.onnx", verbose=True)
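Before looking at the export step, it can help to confirm that the Function itself behaves like ReLU in eager mode, so that any remaining failure is clearly on the ONNX side. This is a minimal sketch: it stores the input with ctx.save_for_backward (the standard way to pass tensors from forward to backward), omits the symbolic method because nothing is exported here, and uses a hypothetical helper, compare_with_builtin_relu, to check outputs and gradients against torch.relu on a fixed input that avoids the kink at 0.

```python
import torch
from torch.autograd import Function


class MyReLUFunction(Function):
    @staticmethod
    def forward(ctx, input):
        # Keep the input tensor for the backward mask
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Gradient is zero where the input was negative
        grad_input.masked_fill_(input < 0, 0)
        return grad_input


def compare_with_builtin_relu(x):
    """Hypothetical helper: run x through MyReLUFunction and torch.relu,
    returning (custom_out, builtin_out, custom_grad, builtin_grad)."""
    a = x.clone().requires_grad_(True)
    b = x.clone().requires_grad_(True)
    out_a = MyReLUFunction.apply(a)
    out_b = torch.relu(b)
    # Summing gives an upstream gradient of ones for every element
    out_a.sum().backward()
    out_b.sum().backward()
    return out_a.detach(), out_b.detach(), a.grad, b.grad


# No exact zeros, since the two implementations may differ at the kink
x = torch.tensor([-2.0, -0.5, 0.5, 3.0])
custom_out, ref_out, custom_grad, ref_grad = compare_with_builtin_relu(x)
print(custom_out, custom_grad)
```

If both the outputs and the gradients match torch.relu, the custom Function is fine and the problem is limited to how the symbolic op is handled during export.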

Following the documentation, I added a symbolic(g, input) static method to the class, but the ONNX export still fails with

RuntimeError: No Op registered for MyReLU with domain_version of 9

So I was wondering what the correct way is to register a custom function when exporting to ONNX.


Is this error solved?

I had the same issue. I then found that the example above works in PyTorch 1.1 and 1.2 but throws errors in newer versions such as PyTorch 1.5 and 1.7. Digging into the PyTorch ONNX documentation, I found that we have to pass an additional argument, operator_export_type, to the torch.onnx.export function (ref: https://pytorch.org/docs/stable/onnx.html#functions).

As I understand the documentation, to support custom layers that define a symbolic function, operator_export_type should be set to OperatorExportTypes.ONNX_ATEN_FALLBACK.

So your line

torch.onnx.export(model, dummy_input, "model.onnx", verbose=True)

Should be replaced with

torch.onnx.export(model, dummy_input, "model.onnx", verbose=True, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)

Then this error should be solved.
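Since this fix depends on which export modes a given PyTorch version provides, it may help to list them before changing the export call. A small sketch, assuming the installed build exposes torch.onnx.OperatorExportTypes as a pybind11-style enum with a __members__ mapping; exact member names can differ between releases.

```python
import torch.onnx

# __members__ maps each enum member's name to its value
available = sorted(torch.onnx.OperatorExportTypes.__members__)
print(available)

# The mode suggested in this answer; an AttributeError here would mean
# the installed version does not offer it
fallback_mode = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
```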
