Here is the relevant piece of my code, in which I register conv2d for the new dispatch key.
When the dispatcher calls into the corresponding wrapper function, how can I tell that conv2d is the operator being dispatched?
// Primary template of the wrapper. It has no members of its own; the
// partial specialization taking a guts::typelist::typelist<Args...> argument
// list is the one that supplies the static call() entry point.
//
//   Redispatch - type of the function being wrapped
//   F          - pointer to the function being wrapped
//   Ret        - return type of the wrapper
//   ArgList    - type describing the wrapper's parameter list
template <
    typename Redispatch,
    Redispatch* F,
    typename Ret,
    typename ArgList>
struct CPU_WrapFunction_ {};
template<class Registered, class Redispatch, Redispatch* F>
struct WrapFunction_CPU final {
using type = CPU_WrapFunction_<Redispatch, F,
typename guts::function_traits<Registered>::return_type,
typename guts::function_traits<Registered>::parameter_types>;
};
// Partial specialization selected when the argument list is a
// guts::typelist::typelist<Args...>. It unpacks the list into the parameter
// pack `Args` and exposes a static call() with exactly that signature.
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct CPU_WrapFunction_<Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  // Wrapper entry point: receives the operator's arguments and passes them
  // on, unchanged, to the wrapped function *F, returning its result.
  static Ret call(Args... args) {
    // Open question: how to get the op's name here?
    // Nothing visible in this scope carries it: the template parameters are
    // only Redispatch, F, Ret and Args, and the name string (REGISTER_NAME in
    // KERNEL_CPU) is handed to m.impl() only, never to this wrapper.
    // NOTE(review): it would presumably have to be threaded through as an
    // extra compile-time parameter, or obtained via a boxed kernel that is
    // given the operator handle -- confirm against the dispatcher docs.
    return (*F)(args...);
  }
};
// Registers a wrapped kernel for the operator "aten::" REGISTER_NAME.
//
//   FUNC          - function whose address becomes the wrapped callee
//   REGISTER_NAME - string literal appended to "aten::" to form the op name
//   SIGNATURE     - function type, used both as the registered signature and
//                   as the type of FUNC
//
// Expects a variable named `m` to be in scope at the expansion site. The
// expansion already ends with ';', so call sites do not add one.
#define KERNEL_CPU(FUNC, REGISTER_NAME, SIGNATURE)  \
  m.impl(                                           \
      TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
      &WrapFunction_CPU<SIGNATURE, SIGNATURE, &FUNC>::type::call);
// Registration block for the `aten` namespace under the `mykey` key; `m` is
// the handle that KERNEL_CPU calls m.impl() on.
TORCH_LIBRARY_IMPL(aten, mykey, m) {
  // Route aten::conv2d through the wrapper around at::conv2d.
  // KERNEL_CPU supplies the terminating ';' itself.
  KERNEL_CPU(
      at::conv2d,
      "conv2d",
      Tensor(
          const Tensor&,
          const Tensor&,
          const c10::optional<Tensor>&,
          IntArrayRef,
          IntArrayRef,
          IntArrayRef,
          int64_t))
}