JIT compile and load custom operations

Hi, I built my custom operation and the compilation succeeded, but when I load my op I get the error below. The same error happens even when I load the example from the documentation. Could you please help me out? Thank you so much!

Loading the op gives this error:
>>> print(torch.ops.my_ops.SpikeFunction)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/calc/guozhang/anaconda3/lib/python3.8/site-packages/torch/_ops.py", line 61, in __getattr__
    op = torch._C._jit_get_operation(qualified_op_name)
RuntimeError: No such operator my_ops::SpikeFunction

The code I used to load the module:

>>> torch.utils.cpp_extension.load(
...     name="SpikeFunction",
...     sources=["spikefunction.cpp"],
...     is_python_module=False,
...     verbose=True
... )
Using /home/guozhang/.cache/torch_extensions as PyTorch extensions root...
Emitting ninja build file /home/guozhang/.cache/torch_extensions/SpikeFunction/build.ninja...
Building extension module SpikeFunction...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/2] c++ -MMD -MF spikefunction.o.d -DTORCH_EXTENSION_NAME=SpikeFunction -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="_gcc" -DPYBIND11_STDLIB="_libstdcpp" -DPYBIND11_BUILD_ABI="_cxxabi1011" -isystem /calc/guozhang/anaconda3/lib/python3.8/site-packages/torch/include -isystem /calc/guozhang/anaconda3/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /calc/guozhang/anaconda3/lib/python3.8/site-packages/torch/include/TH -isystem /calc/guozhang/anaconda3/lib/python3.8/site-packages/torch/include/THC -isystem /calc/guozhang/anaconda3/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++14 -c /home/guozhang/multiplicative_rule/eprop/spikefunction.cpp -o spikefunction.o
[2/2] c++ spikefunction.o -shared -L/calc/guozhang/anaconda3/lib/python3.8/site-packages/torch/lib -lc10 -ltorch_cpu -ltorch -ltorch_python -o SpikeFunction.so
Loading extension module SpikeFunction...

The C++ code I wrote:
#include <torch/all.h>
#include <torch/python.h>

class SpikeFunction : public torch::autograd::Function<SpikeFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::Variable v_scaled,
      torch::autograd::Variable dampening_factor) {
    // forward calculation
    ctx->save_for_backward({v_scaled, dampening_factor});
    return {torch::greater(v_scaled, 0.)};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    // backward calculation
    auto saved = ctx->get_saved_variables();
    auto v_scaled = saved[0];
    auto dampening_factor = saved[1];
    auto dE_dz = grad_output[0];
    auto dz_dv_scaled = torch::maximum(1 - torch::abs(v_scaled), torch::zeros_like(v_scaled)) * dampening_factor;
    auto dE_dv_scaled = dE_dz * dz_dv_scaled;
    return {dE_dv_scaled, torch::zeros_like(dampening_factor)};
  }
};

torch::autograd::variable_list SpikeFunction(const torch::Tensor& v_scaled, const torch::Tensor& dampening_factor) {
  return SpikeFunction::apply(v_scaled, dampening_factor);
}
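For reference, the forward and backward passes above amount to a Heaviside step with a triangular surrogate derivative. Writing v for v_scaled, γ for dampening_factor and E for the loss, this is my reading of the code rather than something stated in the thread:

```latex
z = H(v) = \begin{cases} 1, & v > 0 \\ 0, & v \le 0 \end{cases}
\qquad
\frac{\partial E}{\partial v} = \frac{\partial E}{\partial z}\,\gamma\,\max\bigl(1 - |v|,\ 0\bigr)
\qquad
\frac{\partial E}{\partial \gamma} := 0
```

The last term corresponds to returning torch::zeros_like(dampening_factor) as the second gradient.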

It seems that you're not registering the operator:

static auto registry =
torch::RegisterOperators().op("myops::SpikeFunction", &SpikeFunction);

Thank you for your suggestion. Although I don't know why it didn't work in my case, I searched for the keyword "register" and found that I should add the code below at the end of the file. It works.

TORCH_LIBRARY(my_ops, m) {
  m.def("SpikeFunction", &SpikeFunction);
}

But I have a follow-up question: the returned tensor's requires_grad is False, so I cannot call backward on it. Initially I thought this might be because the output is a boolean tensor, so I tried a square operation instead, but the same issue appeared.

>>> a = torch.rand(5, requires_grad=True)
>>> b = torch.ops.my_ops.SpikeFunction(a, torch.tensor(0.3))
>>> b[0]
tensor([True, True, True, True, True])

>>> b[0].requires_grad
False

By the way, may I ask how to cast a tensor from bool to float in C++? I checked the docs but couldn't figure out how to use them. I tried torch::greater(v_scaled, 0.)._cast_Float(), torch::greater(v_scaled, 0.).at::_cast_Float(), at::_cast_Float(torch::greater(v_scaled, 0.)) and many other patterns; none of them work.

Thank you!

I think that's a different registration type. You need the operator registry (torch::RegisterOperators) for autograd to handle the function, as in my snippet.

By the way, may I ask how to cast tensor from logic to float in c++?

Python ops are normally mirrored in C++; in this case the at::Tensor::to(ScalarType) overload would do it, e.g. torch::greater(v_scaled, 0.).to(torch::kFloat).
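A minimal Python sketch of the same cast, assuming a standard PyTorch install (the tensor values are made up for illustration). The C++ counterparts in the comments use the at::Tensor::to(ScalarType) overload mentioned above and type_as. It also shows that autograd only tracks floating-point (and complex) tensors, which is presumably one reason a bool output like the one in the first version cannot carry requires_grad.

```python
import torch

v_scaled = torch.tensor([-0.5, 0.2, 1.5])

spikes = v_scaled > 0                      # bool tensor: tensor([False, True, True])
as_float = spikes.to(torch.float32)        # C++ counterpart: spikes.to(torch::kFloat)
as_input_dtype = spikes.type_as(v_scaled)  # C++ counterpart: spikes.type_as(v_scaled)

# A floating-point result may require grad...
as_float.requires_grad_()

# ...but a bool tensor may not: requires_grad_() raises RuntimeError.
try:
    spikes.requires_grad_()
    bool_can_require_grad = True
except RuntimeError:
    bool_can_require_grad = False

print(as_float, as_input_dtype.dtype, bool_can_require_grad)
```

In C++ the float result can be returned directly from forward, so the op's output is a differentiable tensor.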

Thank you. I did what you suggested, but the output tensor still has no grad_fn. Could you please try the C++ code below on your side?

class SpikeFunction : public torch::autograd::Function<SpikeFunction> {
 public:
  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::Variable v_scaled,
      torch::Tensor dampening_factor) {
    // forward calculation
    ctx->save_for_backward({v_scaled,dampening_factor});
    return (v_scaled > 0).type_as(v_scaled);
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    // backward calculation
    auto saved = ctx->get_saved_variables();
    auto v_scaled = saved[0];
    auto dampening_factor = saved[1];
    auto dE_dz = grad_output[0];
    auto dz_dv_scaled = torch::maximum(1 - torch::abs(v_scaled), torch::zeros_like(v_scaled)) * dampening_factor;
    auto dE_dv_scaled = dE_dz * dz_dv_scaled;
    return {dE_dv_scaled, torch::Tensor()};
  }
};

torch::Tensor SpikeFunction(torch::Tensor& v_scaled, torch::Tensor& dampening_factor) {
  return SpikeFunction::apply(v_scaled,dampening_factor);
}

// TORCH_LIBRARY(my_ops, m) {
//   m.def("SpikeFunction", &SpikeFunction);
// }

static auto registry = torch::RegisterOperators().op("myops::SpikeFunction", &SpikeFunction);

It worked for me on a 1.8 build:

torch.ops.myops.SpikeFunction(torch.ones(1).requires_grad_(), torch.ones(1).requires_grad_())
Out[10]: tensor(cpu,(1,)[1.], grad_fn=<CppNode>)

Note that at least one of the input tensors must have requires_grad=True; otherwise the backward pass is skipped and grad_fn is not set.
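To cross-check the C++ op, here is a minimal Python sketch of the same function built on torch.autograd.Function. The class name SpikeFunctionRef is hypothetical (not from the thread), and it assumes a standard PyTorch install. It illustrates the note above: grad_fn is None when no input requires grad, and set (with a working backward) when one does.

```python
import torch


class SpikeFunctionRef(torch.autograd.Function):
    """Hypothetical Python mirror of the second C++ SpikeFunction, for cross-checking."""

    @staticmethod
    def forward(ctx, v_scaled, dampening_factor):
        ctx.save_for_backward(v_scaled, dampening_factor)
        # Cast the bool mask to v_scaled's dtype, like type_as in the C++ version.
        return (v_scaled > 0).type_as(v_scaled)

    @staticmethod
    def backward(ctx, grad_output):
        v_scaled, dampening_factor = ctx.saved_tensors
        # Triangular surrogate: dampening_factor * max(1 - |v|, 0)
        dz_dv_scaled = torch.clamp(1 - v_scaled.abs(), min=0) * dampening_factor
        # None plays the role of the undefined torch::Tensor() returned in C++.
        return grad_output * dz_dv_scaled, None


dampening = torch.tensor(0.3)

# No input requires grad: autograd records nothing, so grad_fn stays None.
out_plain = SpikeFunctionRef.apply(torch.tensor([0.5]), dampening)
print(out_plain.grad_fn)

# One input requires grad: grad_fn is set and backward() reaches v.
v = torch.tensor([-0.5, 0.25, 2.0], requires_grad=True)
out = SpikeFunctionRef.apply(v, dampening)
out.sum().backward()
print(out.grad_fn, v.grad)
```

Comparing v.grad against the gradient produced through torch.ops.myops.SpikeFunction on the same inputs is a quick way to confirm the C++ backward is wired up.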

Got it. It works! Thank you so much! Have a great weekend!