Registering backward C++ functions for TorchScript

You implement it in C++ much as you would in Python, by subclassing torch::autograd::Function. You then register an op that calls your autograd function. See this example:

#include <torch/script.h>
#include <torch/all.h>

#include <iostream>
#include <memory>

using namespace at;
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

// computes f(x) = 2x
class MyDouble : public torch::autograd::Function<MyDouble> {
 public:
  static variable_list forward(
      AutogradContext* ctx,
      Variable input) {
    return {input + input};
  }

  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_output) {
    // d/dx (2x) = 2, so scale the incoming gradient by 2
    return {grad_output[0] * 2};
  }
};

Tensor double_op(const Tensor& input) {
  return MyDouble::apply(input)[0];
}

// expose double_op to TorchScript under the name my_ops::double_op
static auto registry =
  torch::RegisterOperators("my_ops::double_op", &double_op);
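
The backward above never touches its context, but most real backward functions need the forward inputs. As a minimal sketch (MySquare, square_op, and my_ops::square_op are illustrative names, not part of the original example), you can stash tensors with ctx->save_for_backward in forward and retrieve them with ctx->get_saved_variables in backward:

#include <torch/script.h>
#include <torch/all.h>

using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

// computes f(x) = x * x, keeping x around for the backward pass
class MySquare : public torch::autograd::Function<MySquare> {
 public:
  static variable_list forward(
      AutogradContext* ctx,
      Variable input) {
    ctx->save_for_backward({input});
    return {input * input};
  }

  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_output) {
    Variable input = ctx->get_saved_variables()[0];
    // d/dx (x^2) = 2x, chained with the incoming gradient
    return {grad_output[0] * 2 * input};
  }
};

Tensor square_op(const Tensor& input) {
  return MySquare::apply(input)[0];
}

static auto square_registry =
    torch::RegisterOperators("my_ops::square_op", &square_op);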

Back to double_op: once the op is compiled into a shared library, you can call it from Python:

import torch

torch.ops.load_library("build/libmy_custom_op.so")

@torch.jit.script
def fn(x):
    return torch.ops.my_ops.double_op(x)

x = torch.randn(2, 2, requires_grad=True)
y = fn(x)
print(y)
y.backward(torch.ones(2, 2))
print(x.grad)  # a 2x2 tensor of 2s, since d(2x)/dx = 2
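
Note that the custom backward is not tied to TorchScript: torch.ops.my_ops.double_op dispatches to MyDouble::apply whether or not the caller is scripted, so the same gradient should also come out when you call the op in eager mode.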