How to trace external modules for model serialization

Hi. I originally posted this question under the quantization tag, but the problem seems to be how to save a model that contains external modules with torch.jit.save. Is this possible? I'd really appreciate any help you can provide.

There are a few options:

  • The official answer is to provide a custom operator in C++ (as torchvision does for, e.g., nms) and then call it through torch.ops.mymodule.opname. This is compatible with the JIT, including saving and loading.

  • The JIT has a Python fallback: if you decorate a function with @torch.jit.ignore and call it from your JITed function, the call is handed back to the Python interpreter. This lets you script the model, but you won't be able to save it.

  • Alternatively, you could register a “stub” op that calls back into Python, or write a small surgery helper that replaces the Python fallbacks with that stub op before saving and swaps them back after loading.
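A minimal sketch of the difference between the fallback and a registered operator, assuming a recent PyTorch where torch.library.Library is available. Here the operator is registered from Python instead of C++, so its kernel calls back into Python, which is roughly the "stub op" idea above. The names external_fn, mymodule and opname are hypothetical placeholders, and behaviour may differ between PyTorch versions.

```python
import io

import torch


# Hypothetical stand-in for code from an "external" module that
# TorchScript cannot compile on its own.
def external_fn(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1


# --- Python fallback via @torch.jit.ignore ---------------------------------
class WithIgnore(torch.nn.Module):
    @torch.jit.ignore
    def ext(self, x: torch.Tensor) -> torch.Tensor:
        # Left as a plain Python call; TorchScript dispatches it back to Python.
        return external_fn(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ext(x)


scripted_ignore = torch.jit.script(WithIgnore())
x = torch.arange(4.0)
print(scripted_ignore(x))  # running works in this Python process

try:
    torch.jit.save(scripted_ignore, io.BytesIO())
    save_failed = False
except Exception as err:  # the Python call cannot be serialized
    save_failed = True
    print("save failed:", type(err).__name__)


# --- Registered operator: torch.ops.mymodule.opname -------------------------
# Assumption: "mymodule" / "opname" are placeholder names. The Library object
# must stay alive, otherwise the operator is deregistered.
_lib = torch.library.Library("mymodule", "DEF")
_lib.define("opname(Tensor x) -> Tensor")
_lib.impl("opname", external_fn, "CPU")  # kernel implemented in Python


class WithOp(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compiled as a call to the operator "mymodule::opname".
        return torch.ops.mymodule.opname(x)


buf = io.BytesIO()
torch.jit.save(torch.jit.script(WithOp()), buf)
buf.seek(0)
# Loading only works where "mymodule::opname" has been registered (here: same process).
loaded = torch.jit.load(buf)
print(loaded(x))
```

The saved file only stores a reference to mymodule::opname, so whoever loads it has to register the operator first, from a C++ extension or, as in this sketch, from Python.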

Best regards

Thomas