You implement it in C++ the same way as in Python: subclass `torch::autograd::Function`, the C++ counterpart of Python's `torch.autograd.Function`. You then register an op that calls your autograd function. See this example:
#include <torch/script.h>
#include <torch/all.h>
#include <iostream>
#include <memory>
using namespace at;
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
// Computes f(x) = 2x as a custom autograd function with a hand-written
// backward pass.
class MyDouble : public torch::autograd::Function<MyDouble> {
 public:
  /// Forward pass.
  /// @param ctx   autograd context (unused: backward needs no saved tensors).
  /// @param input the tensor to double.
  /// @return a one-element list holding input + input.
  static variable_list forward(AutogradContext* ctx, Variable input) {
    return {input + input};
  }

  /// Backward pass.
  /// @param ctx         autograd context (unused).
  /// @param grad_output grad_output[0] is the gradient w.r.t. the single
  ///                    output produced by forward().
  /// @return a one-element list holding the gradient w.r.t. `input`.
  static variable_list backward(AutogradContext* ctx,
                                variable_list grad_output) {
    // d(2x)/dx = 2, so the chain rule gives 2 * upstream gradient.
    // Scaling grad_output (instead of returning a fixed ones({2, 2}) + 1)
    // keeps the result correct for any upstream gradient and any input shape.
    return {grad_output[0] * 2};
  }
};
/// Applies MyDouble to `input` through autograd and returns its only output.
Tensor double_op(const Tensor& input) {
  const variable_list outputs = MyDouble::apply(input);
  return outputs.front();
}
// Registers double_op under the name "my_ops::double_op" during static
// initialization of this library. Loading the library from Python
// (torch.ops.load_library) therefore makes torch.ops.my_ops.double_op available.
// NOTE(review): torch::RegisterOperators is the legacy registration API;
// newer PyTorch releases recommend TORCH_LIBRARY instead.
static auto registry =
    torch::RegisterOperators().op("my_ops::double_op", &double_op);
And in Python:
import torch

# Loading the shared library runs its static registration of my_ops::double_op.
torch.ops.load_library("build/libmy_custom_op.so")


@torch.jit.script
def fn(x):
    """Call the custom C++ op my_ops::double_op from TorchScript."""
    return torch.ops.my_ops.double_op(x)


x = torch.randn(2, 2, requires_grad=True)
y = fn(x)
print(y)
# Backpropagate from the op's output, not from the leaf `x`. Calling
# x.backward(...) would only seed x.grad and never run MyDouble::backward.
y.backward(torch.ones(2, 2))
print(x.grad)  # expected: a 2x2 tensor of 2s, since d(2x)/dx = 2