I have succeeded in adding my custom primitives to PyTorch via `RegisterOperators` and `RegisterPass` from within PyBind11.
Models from TorchVision that I load and transform using `torch.jit.trace` definitely have them: I see their console output, and the output tensor bears clear marks of their work.
However, when I `torch.jit.save` the traced model and `torch.jit.load` it back later, my primitives are nowhere to be seen.
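A minimal reproduction sketch of the trace/save/load round-trip, in case it helps pin down where the primitives disappear. It is not my actual setup: instead of the C++ `RegisterOperators` extension it registers a stand-in operator from Python with `torch.library` (assumed to be available in a reasonably recent PyTorch), and the names `myops`, `scale2` and `Net` are hypothetical. It only covers an operator called explicitly in `forward`; a graph pass registered with `RegisterPass` is not exercised here. The script checks whether the operator node is present in the traced graph and in the graph of the reloaded module.

```python
import io

import torch

# Hypothetical namespace/operator registered from Python with torch.library,
# standing in for an operator registered in C++ via RegisterOperators.
# Keep a reference to the Library object: the registration is dropped
# when it is garbage-collected.
_lib = torch.library.Library("myops", "DEF")
_lib.define("scale2(Tensor x) -> Tensor")
_lib.impl("scale2", lambda x: x * 2, "CPU")


class Net(torch.nn.Module):
    def forward(self, x):
        # The operator is called explicitly, so tracing should record it
        # as a node in the graph.
        return torch.ops.myops.scale2(x) + 1


example = torch.ones(3)
traced = torch.jit.trace(Net(), example)

# Round-trip through an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)

# Loading can only resolve myops::scale2 if the operator is registered in
# the loading process, i.e. the registration code above has already run.
loaded = torch.jit.load(buffer)

# Inspect both graphs for the custom operator node.
print("myops::scale2" in str(traced.graph))
print("myops::scale2" in str(loaded.graph))
print(loaded(example))
```

Comparing `str(traced.graph)` with `str(loaded.graph)` in the same way for the real models is one way to see whether the primitives were ever part of the saved graph or only show up at run time.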
What am I doing wrong? How do I use my own primitives in a network loaded using `torch.jit.load`?