How to combine the training forward function and the TorchScript forward function?

Hi,
I have a custom nn.Module that has custom C++ forward and backward functions. I bind these functions with a PyTorch extension:

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("my_forward_cpu", &my_forward_cpu);
  m.def("my_backward_cpu", &my_backward_cpu);
}

Now I also want to use my_forward_cpu for inference with TorchScript, but it seems that I need to write

TORCH_LIBRARY(my_ops, m) {
  m.def("my_forward_cpu", my_forward_cpu);
}

and then choose between the two bindings in Python:

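import torch
import torch.nn as nn
import myop.backend
torch.ops.load_library("build/libmyfunc.so")

class MyModule(nn.Module):
    ...
    def forward(self, x):
        if torch.jit.is_tracing():
            # traced graph: call the op registered with TORCH_LIBRARY
            return torch.ops.my_ops.my_forward_cpu(x)
        else:
            # eager training: call the pybind11 binding
            return myop.backend.my_forward_cpu(x)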

Is this how it is supposed to be done, or can I use a single my_forward_cpu binding for both training and TorchScript?
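A minimal sketch of one possible way to use a single registration, under stated assumptions: register the forward only as an operator (as TORCH_LIBRARY does), call torch.ops.my_ops.my_forward_cpu on every path, and attach the backward with torch.autograd.Function for eager training. Because the C++ extension isn't available here, the sketch stands in for the TORCH_LIBRARY block and build/libmyfunc.so with a placeholder kernel (x * x) registered through the Python torch.library API. The placeholder math, the helper names _my_forward_impl and _train_forward, and the hypothetical my_backward_cpu(grad_out, x) signature are assumptions, not the actual extension. The autograd wrapper is hidden from the TorchScript compiler with @torch.jit.unused, and the branch checks torch.jit.is_scripting().

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the C++ TORCH_LIBRARY(my_ops, m) block: it registers
# my_forward_cpu with the dispatcher, so the same op is reachable as
# torch.ops.my_ops.my_forward_cpu from eager Python and from TorchScript.
# The library object must stay alive, so it is kept at module level.
_lib = torch.library.Library("my_ops", "DEF")
_lib.define("my_forward_cpu(Tensor x) -> Tensor")


def _my_forward_impl(x):
    # Placeholder kernel; the real one would be the C++ my_forward_cpu.
    return x * x


_lib.impl("my_forward_cpu", _my_forward_impl, "CPU")


def my_backward_cpu(grad_out, x):
    # Hypothetical stand-in for the pybind11 backward: d(x * x)/dx = 2x.
    return grad_out * 2 * x


class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # Same registered op as the scripted inference path.
        return torch.ops.my_ops.my_forward_cpu(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return my_backward_cpu(grad_out, x)


@torch.jit.unused
def _train_forward(x: torch.Tensor) -> torch.Tensor:
    # Not compiled by TorchScript (replaced with a raising stub); eager mode only.
    return MyFunction.apply(x)


class MyModule(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.jit.is_scripting():
            # Inference in TorchScript: call the registered op directly.
            return torch.ops.my_ops.my_forward_cpu(x)
        # Eager training: same op, wrapped with the custom backward.
        return _train_forward(x)
```

With this layout the PYBIND11_MODULE binding for the forward would no longer be needed; the backward could stay a pybind11 binding, since the scripted module is only used for inference and never calls it.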


I ran into the same problem. Is there any update?