TorchScript: registering backward functions for custom C++ ops

I just read this tutorial, and it does not cover the backward part.
https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html

My understanding is that to get autograd to work, we also need to register the corresponding backward function. However, I could not find how this works. I checked the PyTorch source code (torch/csrc/autograd/generated/VariableTypeEverything.cpp) and could not find where the backward functions are registered either.

Any help would be appreciated!

You implement it in C++ similarly to how you would in Python, via autograd.Function. You then have to register an op that uses your autograd function. See this example:

#include <torch/script.h>
#include <torch/all.h>

#include <iostream>
#include <memory>

using namespace at;
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

// computes f(x) = 2x
class MyDouble : public torch::autograd::Function<MyDouble> {
 public:
  static variable_list forward(
      AutogradContext* ctx,
      Variable input) {
    return {input + input};
  }

  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_output) {
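    // NOTE: this hardcodes the gradient to a 2x2 tensor of 2s and ignores
    // grad_output, so it does not apply the chain rule.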
    return {torch::ones({2, 2}) + 1};
  }
};

Tensor double_op(const Tensor& input) {
  return MyDouble::apply(input)[0];
}

static auto registry =
  torch::RegisterOperators("my_ops::double_op", &double_op);

and in Python

import torch

torch.ops.load_library("build/libmy_custom_op.so")

@torch.jit.script
def fn(x):
    return torch.ops.my_ops.double_op(x)

x = torch.randn(2, 2, requires_grad=True)
y = fn(x)
print(y)
y.backward(torch.ones(2, 2))
print(x.grad)

Hi, I saw this thread and thought I'd post an interesting issue. Perhaps you can see what's wrong?

If you modify your Python script to:

import torch

torch.ops.load_library("build/libmy_custom_op.so")

@torch.jit.script
def fn(x):
    xxx = torch.pow(x, 2)
    return xxx * torch.ops.my_ops.double_op(x)

x1 = torch.tensor([[0.3, 0.5], [2.0, 0.7]], requires_grad=True)
x2 = torch.tensor([[0.3, 0.5], [2.0, 0.7]], requires_grad=True)

out = fn(x1)
print(out)
out.backward(torch.ones(2, 2))

print(x1.grad)

torch_only = torch.sum(torch.pow(x2, 2) * 2 * x2)
print(torch_only)

torch_only.backward()
print(x2.grad)

then you'll notice that the gradients from the torch-only function and from the composite TorchScript + PyTorch version are different. The torch-only function produces the correct gradients.
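
My guess is that the difference comes from the hardcoded backward in the C++ example above: it always returns a constant tensor of 2s and ignores grad_output, so the chain rule is not applied once double_op is composed with the torch.pow(x, 2) term. For comparison, a backward that scales the incoming gradient instead would look something like the sketch below (it reuses the includes and using declarations from the example above; MyDoubleChainRule is just an illustrative name).

// computes f(x) = 2x, but propagates grad_output through the chain rule
class MyDoubleChainRule : public torch::autograd::Function<MyDoubleChainRule> {
 public:
  static variable_list forward(
      AutogradContext* ctx,
      Variable input) {
    return {input + input};
  }

  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_output) {
    // df/dx = 2, so scale the upstream gradient by 2
    return {grad_output[0] * 2};
  }
};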