C++ autograd for torch::jit::script::Module

Hello, everybody!

I’m trying to evaluate the gradients of the model in C++ code using torch/script.

The original model was constructed on the Python side with:

#!/usr/bin/python3

import sys
import torch
import torch.nn as nn

class MLP(nn.Sequential):
  def __init__(self, layers=(7, 1)):
    modules = []
    for idx in range(len(layers) - 1):
      modules.append(nn.Linear(layers[idx], layers[idx + 1]))
      modules.append(nn.ELU())
    # drop the trailing activation so the last layer is Linear (no final ELU)
    super().__init__(*modules[:-1])

Then I created a model and saved it to a file:

model = MLP()
a = torch.ones((7), requires_grad=True)
traced_model_gen = torch.jit.trace(model, a)
traced_model_gen.save("traced_jit_model.pt")

Then, I loaded the model and evaluated its gradients:

jit_model = torch.jit.load("traced_jit_model.pt")
res1 = jit_model.forward(a)
print(res1)
grad1 = torch.autograd.grad(res1, a, create_graph=True, only_inputs=True)
print(grad1)

And it works fine.

Then I tried to load this model into C++ and evaluate the derivatives there as well:

#include <iostream>
#include <torch/script.h>
#include <torch/csrc/autograd/autograd.h>
int main()
{
  torch::jit::script::Module module = torch::jit::load("traced_jit_model.pt");
  static float point[] = { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 };
  at::Tensor input = torch::from_blob(point, {7}, torch::TensorOptions().dtype(torch::kFloat));
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(input);
  auto outputs = module.forward(inputs);
  // the next line produces the same value as res1 in the Python code
  torch::Tensor res1 = outputs.toTensor();
  std::cout << res1[0].item<float>() << std::endl;
  auto grad1 = torch::autograd::grad(outputs, inputs, true);
}

I’m trying to evaluate the gradients in the last line, but unfortunately compilation fails with a type mismatch:

basic_torch.cpp:76:37: error: invalid initialization of reference of type ‘const variable_list&’ {aka ‘const std::vector<at::Tensor>&’} from expression of type ‘c10::IValue’
   76 |   auto grad = torch::autograd::grad(outputs, inputs, true);
      |                                     ^~~~~~~
In file included from basic_torch.cpp:3:
libtorch/include/torch/csrc/autograd/autograd.h:74:26: note: in passing argument 1 of ‘torch::autograd::variable_list torch::autograd::grad(const variable_list&, const variable_list&, const variable_list&, c10::optional<bool>, bool, bool)’
   74 |     const variable_list& outputs,
      |     ~~~~~~~~~~~~~~~~~~~~~^~~~~~~

I looked at the Autograd in C++ Frontend tutorial (PyTorch Tutorials 1.10.0+cu102 documentation) to understand how to evaluate the gradients of the model, but I did not find the relevant information there.

Are there any manuals on how to evaluate model gradients with autograd?

UPD:
libtorch version: 1.11.0.dev20211028+cpu

The C++ API doesn’t have all the “pass a tensor and it’ll be treated as a 1-element list of tensors” functions that PyTorch has in Python.
In other words, use torch::autograd::grad({outputs}, {inputs}, /*grad_outputs=*/{}, /*retain_graph=*/c10::nullopt, /*create_graph=*/true); or so, as seen in the higher-order gradient example for C++.
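
For example, here is a minimal sketch along those lines (variable names are illustrative; module is the script module loaded earlier). The input has to require grad, which the from_blob tensor above does not, and note that the fourth positional parameter of torch::autograd::grad is retain_graph, so create_graph goes in the fifth slot:

// minimal sketch: wrap the single output/input tensors in braced lists
auto input = torch::ones({7}).requires_grad_(true);   // input must require grad
auto output = module.forward({input}).toTensor();
auto grads = torch::autograd::grad(
    /*outputs=*/{output},
    /*inputs=*/{input},
    /*grad_outputs=*/{torch::ones_like(output)},
    /*retain_graph=*/c10::nullopt,
    /*create_graph=*/true);
std::cout << grads[0] << std::endl;  // d(output)/d(input), a 7-element tensor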

Best regards

Thomas

Thank you, Thomas!

I successfully translated the code from Python to C++!

Here is the answer:

torch::jit::script::Module module = torch::jit::load("traced_jit_model.pt");
auto input1 = torch::ones({7}).requires_grad_(true);
std::vector<torch::jit::IValue> inputs;
inputs.push_back(input1);
auto output1 = module.forward(inputs);
auto grad_output1 = torch::ones_like(output1.toTensor());
auto gradient1 = torch::autograd::grad({output1.toTensor()}, {input1},
                                       /*grad_outputs=*/{grad_output1},
                                       /*retain_graph=*/c10::nullopt,
                                       /*create_graph=*/true);
std::cout << gradient1[0] << std::endl;

It was much simpler than I thought!

Best regards,
Igor