Coredump when registering a dispatched operator in C++ with a CUDA backend only

I registered an op in C++ following https://pytorch.org/tutorials/advanced/dispatcher.html#for-operators-that-do-not-need-autograd. I implement only a CUDA backend, and the autograd is also implemented in C++. As the tutorial notes, the autograd kernel needs to redispatch: it must call back into the dispatcher to reach the inference kernels, e.g. the CPU or CUDA implementations.
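
My autograd kernel follows that pattern. Roughly, it looks like the sketch below (simplified, with placeholder names, not my exact code; it redispatches through the dispatcher wrappers at::native::fuse_softmax_forward / fuse_softmax_backward shown further down):

#include <torch/autograd.h>
#include <torch/types.h>

// Simplified sketch of the autograd kernel (class and helper names are
// placeholders). It redispatches to the backend kernel through the
// dispatcher wrappers defined in torch_softmax.cpp below.
class FuseSoftmaxFunction
    : public torch::autograd::Function<FuseSoftmaxFunction> {
 public:
  static torch::Tensor forward(torch::autograd::AutogradContext *ctx,
                               const torch::Tensor &input, int64_t dim,
                               bool half_to_float) {
    // Drop below the autograd dispatch key so the redispatch reaches the
    // backend (CUDA) kernel instead of recursing into this autograd kernel.
    at::AutoDispatchBelowADInplaceOrView guard;
    auto output = at::native::fuse_softmax_forward(input, dim, half_to_float);
    ctx->save_for_backward({output, input});
    ctx->saved_data["dim"] = dim;
    return output;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext *ctx,
      torch::autograd::variable_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto grad_input = at::native::detail::fuse_softmax_backward(
        grad_outputs[0], saved[0], ctx->saved_data["dim"].toInt(), saved[1]);
    // dim and half_to_float are non-tensor inputs, so they get no gradient.
    return {grad_input, torch::Tensor(), torch::Tensor()};
  }
};

torch::Tensor fuse_softmax_forward_autograd(const torch::Tensor &input,
                                            int64_t dim, bool half_to_float) {
  return FuseSoftmaxFunction::apply(input, dim, half_to_float);
}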

I compile the C++ extension successfully, but when I import the new Python module, a coredump occurs. I debugged it with gdb and the stack trace is as follows:


torch_softmax.cpp is the file that defines the dispatcher wrappers the autograd kernel calls into. Part of it looks like this:

#include "torch_softmax.h"

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include <torch/types.h>

namespace at {
namespace native {

Tensor fuse_softmax_forward(const Tensor &input, const int64_t dim, const bool half_to_float) {
  // C10_LOG_API_USAGE_ONCE("lightop.csrc.torch_softmax.fuse_softmax_forward");
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("myops::fuse_softmax_forward", "")
                       .typed<decltype(fuse_softmax_forward)>();
  return op.call(input, dim, half_to_float);
}

namespace detail {

Tensor fuse_softmax_backward(const Tensor& grad, const Tensor& output, int64_t dim, const Tensor& input) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("myops::fuse_softmax_backward", "")
          .typed<decltype(fuse_softmax_backward)>();
  return op.call(
      grad,
      output,
      dim,
      input);
}

} // namespace detail

TORCH_LIBRARY_FRAGMENT(myops, m) {
  m.def("fuse_softmax_forward(const Tensor &input, const int64_t dim, const bool half_to_float) -> Tensor");

  m.def("fuse_softmax_backward(const Tensor& grad, const Tensor& output, int64_t dim, const Tensor& input) -> Tensor");

}

} // namespace native
} // namespace at
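
The kernel registration lives alongside this and roughly follows the tutorial's pattern; the sketch below uses placeholder function names (fuse_softmax_forward_cuda, etc.), not my exact code:

#include <torch/library.h>

// Placeholder declarations standing in for my actual CUDA kernels and the
// autograd wrapper sketched above.
at::Tensor fuse_softmax_forward_cuda(const at::Tensor &input, int64_t dim, bool half_to_float);
at::Tensor fuse_softmax_backward_cuda(const at::Tensor &grad, const at::Tensor &output, int64_t dim, const at::Tensor &input);
at::Tensor fuse_softmax_forward_autograd(const at::Tensor &input, int64_t dim, bool half_to_float);

// CUDA is the only backend I implement.
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  m.impl("fuse_softmax_forward", fuse_softmax_forward_cuda);
  m.impl("fuse_softmax_backward", fuse_softmax_backward_cuda);
}

// Autograd kernel that redispatches to the backend kernel above.
TORCH_LIBRARY_IMPL(myops, Autograd, m) {
  m.impl("fuse_softmax_forward", fuse_softmax_forward_autograd);
}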

I want to keep the autograd implementation in C++ for performance. The stack trace shows the op being called through the JIT path, so I suspect the problem is related to dynamic dispatch selection. But I don't know how to dispatch explicitly to the CUDA backend, or how to debug this further. Could someone give me some advice?