How to combine the training forward function and the TorchScript forward function?

I have a custom nn.Module with custom C++ forward and backward functions. I bind these functions with a PyTorch extension:

  m.def("my_forward_cpu", &my_forward_cpu);
  m.def("my_backward_cpu", &my_backward_cpu);

Now I also want to use my_forward_cpu for inference with TorchScript, but it seems that I need to write

TORCH_LIBRARY(my_ops, m) {
  m.def("my_forward_cpu", my_forward_cpu);
}

and then switch between the two bindings in my module:

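import torch
import torch.nn as nn
import myop.backend

class MyModule(nn.Module):
    def forward(self, x):
        if torch.jit.is_tracing():
            # TorchScript path: op registered via TORCH_LIBRARY
            return torch.ops.my_ops.my_forward_cpu(x)
        # Training path: function bound via the pybind11 extension
        return myop.backend.my_forward_cpu(x)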

Is this the way it should be done? Or can I use a single my_forward_cpu binding for both training and TorchScript?
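One possible way to avoid the duplicate bindings, sketched under assumptions: register the kernels only once with the dispatcher (the TORCH_LIBRARY route) and call them as torch.ops.my_ops.* from both the training code and the TorchScript code, with a torch.autograd.Function supplying the backward for eager training. To keep the sketch runnable without the compiled extension, it assumes PyTorch 2.x and defines the ops with torch.library using hypothetical Python stand-in kernels (forward y = 2x, backward dx = 2 * dy). The backward schema my_backward_cpu(Tensor grad_out, Tensor x) is also an assumption. With the real extension, the TORCH_LIBRARY block would def both my_forward_cpu and my_backward_cpu, and the library would be loaded with torch.ops.load_library(...) instead of the torch.library lines.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the C++ TORCH_LIBRARY(my_ops, m) block.
# Keep the Library object alive, otherwise the registrations are removed.
_lib = torch.library.Library("my_ops", "DEF")
_lib.define("my_forward_cpu(Tensor x) -> Tensor")
_lib.define("my_backward_cpu(Tensor grad_out, Tensor x) -> Tensor")
# Placeholder kernels: forward y = 2x, backward dx = 2 * dy.
_lib.impl("my_forward_cpu", lambda x: x * 2, "CPU")
_lib.impl("my_backward_cpu", lambda grad_out, x: grad_out * 2, "CPU")


class _MyFunction(torch.autograd.Function):
    """Eager training path: pairs the forward op with its backward op."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.ops.my_ops.my_forward_cpu(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return torch.ops.my_ops.my_backward_cpu(grad_out, x)


class MyModule(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.jit.is_scripting():
            # TorchScript only sees the registered op (inference path).
            return torch.ops.my_ops.my_forward_cpu(x)
        return self._train_forward(x)

    @torch.jit.unused
    def _train_forward(self, x: torch.Tensor) -> torch.Tensor:
        # autograd.Function is not scriptable, so the compiler skips this body.
        return _MyFunction.apply(x)
```

Usage: call MyModule()(x) in eager mode for training, and torch.jit.script(MyModule()) for inference. The scripted module only contains the forward op; if gradients were also needed inside TorchScript or from C++, the backward would have to be attached to the op in C++ (for example with an Autograd kernel registered through TORCH_LIBRARY_IMPL) rather than in a Python autograd.Function.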


I ran into the same problem. Is there any update?