Scripting custom autograd function. RuntimeError: You attempted to access the anomaly metadata of a custom autograd function

import torch
import torch.nn as nn

# NOTE(review): this class is part of the minimal reproduction from the
# question; it inherits from both torch.autograd.Function and nn.Module,
# presumably so it can be stored on a module and passed through
# torch.jit.script -- confirm which part actually triggers the error.
class CustomFunction(torch.autograd.Function, nn.Module):
    # Calling an instance forwards the input to ``apply``.
    def __call__(self, input):
        return self.apply(input)

    @staticmethod
    def forward(ctx, forward_in):
        """Clamp ``forward_in`` from below at 0 and save input and output on ``ctx``."""
        forward_out = forward_in.clamp(min=0)
        # Both tensors are read back in backward().
        ctx.save_for_backward(forward_in, forward_out)
        return forward_out

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` with entries zeroed where the forward input was negative."""
        forward_in, forward_out = ctx.saved_tensors

        # Start from a mask of ones, zero it where the input was negative,
        # then scale by the incoming gradient.
        relu_gradients = torch.ones_like(forward_out)
        relu_gradients[forward_in < 0] = 0
        relu_gradients = relu_gradients.mul(grad_output)
        # Some extra functions here (placeholder left by the author for
        # additional gradient processing)
        return relu_gradients


class CustomModel(nn.Module):
    """Module that calls the custom function and returns the gradient of ``x**2`` with respect to ``x``."""
    def __init__(self):
        super(CustomModel, self).__init__()
        # Instance of the custom autograd function, stored as an attribute.
        self.cr = CustomFunction()

    def forward(self, x):
        # Turn on gradient tracking for the input, in place.
        x.requires_grad_(True)
        # NOTE(review): the result of the custom function is discarded here,
        # so it does not feed into ``criterion`` below -- confirm this is
        # intentional for the minimal example.
        self.cr(x)
        criterion = x**2
        # NOTE(review): backward() is called without arguments, which
        # presumably requires ``criterion`` to hold a single element --
        # confirm against the inputs used.
        criterion.backward()
        return x.grad

# Compile the model with torch.jit.script (despite the "traced" in the
# variable name, no tracing is involved), save it to disk, and load it back.
model = CustomModel()
traced_script_module = torch.jit.script(model)
traced_script_module.save("traced_jit_model.pt")
jit_model = torch.jit.load('traced_jit_model.pt')

This is the minimal version of the code that causes the following error:

RuntimeError: You attempted to access the anomaly metadata of a custom autograd function but the underlying PyNode has already been deallocated.  The most likely reason this occurred is because you assigned x.grad_fn to a local variable and then let the original variable get deallocated.  Don't do that!  If you really have no way of restructuring your code so this is the case, please file an issue reporting that you are affected by this.

Is there a way to restructure the code so that this error isn’t thrown? I don’t understand what the error means.

I solved this by implementing the custom autograd function in C++ and registering it as a custom operator, as follows:

#include <torch/script.h>
#include <torch/all.h>

#include <iostream>
#include <memory>


// Custom ReLU with a hand-written backward pass, implemented as a C++
// autograd function so that it can be registered as a custom operator.
class CustomReluOp : public torch::autograd::Function<CustomReluOp>{
    public:
        // Forward pass: clamps `forward_in` from below at 0 and saves both the
        // input and the clamped output on `ctx` for use in backward().
        // Returns a one-element list holding the clamped tensor.
        static torch::autograd::variable_list forward(torch::autograd::AutogradContext* ctx, torch::autograd::Variable forward_in){
            auto forward_out = torch::clamp(forward_in, 0.0);
            ctx->save_for_backward({forward_in, forward_out});
            return {forward_out};
        }

        // Backward pass: builds the gradient from the saved input/output and
        // the incoming gradient `grad_output[0]`.
        // Returns a one-element list holding the gradient for `forward_in`.
        static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output){
            auto list_forward = ctx->get_saved_variables();
            auto forward_in = list_forward[0];
            auto forward_out = list_forward[1];

            auto relu_gradients = torch::ones_like(forward_out);
            // Zero the gradient wherever the input was negative. Indexing with
            // a boolean mask via index() yields a copy, so assigning to that
            // result would leave relu_gradients untouched; masked_fill_
            // performs the write in place.
            relu_gradients.masked_fill_(forward_in < 0, 0);
            relu_gradients = torch::mul(relu_gradients, grad_output[0]);

            // Extra processing on top of the plain ReLU gradient: drop
            // negative gradient entries, then scale by the forward output.
            relu_gradients = torch::nn::functional::relu(relu_gradients);
            relu_gradients = relu_gradients * forward_out;
            return {relu_gradients};
        }
};

// Plain-function entry point for the custom ReLU: runs CustomReluOp on
// `input` and hands back the first tensor of the list it returns.
torch::Tensor custom_relu_op(const torch::Tensor& input) {
  const auto outputs = CustomReluOp::apply(input);
  return outputs.front();
}

// Register custom_relu_op under the name "my_ops::custom_relu_op"; the Python
// code below reaches it as torch.ops.my_ops.custom_relu_op once the shared
// library has been loaded.
static auto registry = torch::RegisterOperators("my_ops::custom_relu_op", &custom_relu_op);

I saved the code above as customrelu.cpp and then created a CMakeLists.txt file in the same folder with the following contents:

cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(custom_relu)

# Request C++14 before any target is created: CMAKE_CXX_STANDARD only
# initializes the CXX_STANDARD property of targets defined after it is set.
set(CMAKE_CXX_STANDARD 14)

find_package(Torch REQUIRED)

# Define our library target
add_library(custom_relu SHARED customrelu.cpp)
# Link against LibTorch
target_link_libraries(custom_relu "${TORCH_LIBRARIES}")

I then created an empty folder named build next to CMakeLists.txt, changed into it, and ran the following:

# Configure the build from the parent directory. The Python one-liner prints
# the CMake prefix path of the installed torch package, which is passed on as
# CMAKE_PREFIX_PATH so that find_package(Torch) can locate it.
cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..

# Compile the shared library (-j without a number puts no limit on the
# number of parallel jobs).
make -j

The following code now uses a custom operation that you can script:

import torch
import torch.nn as nn

# Load the shared library produced by the build so that the operator it
# registers becomes available under torch.ops.my_ops. The path is resolved as
# given; adjust it to point at your own build folder.
torch.ops.load_library("inference/build/libcustom_relu.so")

class CustomModel(nn.Module):
    """Module that calls the registered custom op and returns the gradient of ``x**2`` with respect to ``x``."""
    def __init__(self):
        super(CustomModel, self).__init__()
        # Handle to the operator registered from C++ as "my_ops::custom_relu_op".
        self.cr = torch.ops.my_ops.custom_relu_op

    def forward(self, x):
        # Turn on gradient tracking for the input, in place.
        x.requires_grad_(True)
        # NOTE(review): the result of the custom op is discarded here, so it
        # does not feed into ``criterion`` below -- confirm this is
        # intentional for the minimal example.
        self.cr(x)
        criterion = x**2
        # NOTE(review): backward() is called without arguments, which
        # presumably requires ``criterion`` to hold a single element --
        # confirm against the inputs used.
        criterion.backward()
        return x.grad

# Compile the model with torch.jit.script (despite the "traced" in the
# variable name, no tracing is involved), save it to disk, and load it back.
model = CustomModel()
traced_script_module = torch.jit.script(model)
traced_script_module.save("traced_jit_model.pt")
jit_model = torch.jit.load('traced_jit_model.pt')

I added these instructions because the documentation for the C++ frontend is somewhat daunting for someone with little C++ experience (like me). Credit goes to this post: TorchScript register backward C++ functions