Backward keep_graph in the C++ API

(Martin Huber) #1

I don’t understand how to use PyTorch's C++ API to retain the graph.

How can one achieve loss.backward(retain_graph=True) with it?

inline void Tensor::backward(
    c10::optional<Tensor> gradient,
    bool keep_graph,
    bool create_graph) {
  type().backward(*this, std::move(gradient), keep_graph, create_graph);
}

This is how the backward pass is implemented, but how can one access the keep_graph variable?

(Martin Huber) #2

Here is an example that crashes. It raises the error

the derivative for 'target' is not implemented

My main.cpp:

#include <torch/torch.h>
#include <iostream>

int main() {

	// Layer.
	torch::nn::Linear linear1(5, 1);
	torch::nn::Linear linear2(5, 1);

	// Optimizer.
	torch::optim::Adam opt(linear1->parameters(), torch::optim::AdamOptions(0.001));

	// Input.
	torch::Tensor in = torch::randn({1, 5});

	// Output. Use one of these two definitions. The first one won't work.
	torch::Tensor desired_out = linear2->forward(in); // raises the error "the derivative for 'target' is not implemented"
	//torch::Tensor desired_out = torch::ones({1, 1}); // works perfectly fine

	// Training loop.
	std::cout << "initial: " << linear1->forward(in) << std::endl;

	for (int i = 0; i < 1000; i++) {

		torch::Tensor out = linear1->forward(in);
		torch::Tensor loss = torch::mse_loss(out, desired_out);

		// Standard optimization step.
		opt.zero_grad();
		loss.backward();
		opt.step();
	}

	std::cout << "desired: " << desired_out << std::endl;
	std::cout << "trained: " << linear1->forward(in) << std::endl;

	return 0;
}

My CMakeLists.txt:

cmake_minimum_required(VERSION 3.11 FATAL_ERROR)
project(main)

find_package(Torch REQUIRED)

add_executable(main main.cpp)
target_link_libraries(main "${TORCH_LIBRARIES}")
set_property(TARGET main PROPERTY CXX_STANDARD 17)

(Thomas V) #3

You don’t access it; you pass it. You also need to pass torch::nullopt for the gradient:

loss.backward(torch::nullopt, /*keep_graph=*/ true, /*create_graph=*/ false);
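To see what keep_graph actually controls, here is a toy sketch in plain C++ (no LibTorch). The ToySquare type and its members are hypothetical and only mimic the shape of Tensor::backward(gradient, keep_graph, create_graph): the node saves its input for the backward pass and frees it afterwards unless keep_graph is true, so a second backward on the same graph fails, much like PyTorch's "Trying to backward through the graph a second time" error. An empty gradient plays the role of torch::nullopt and is treated as an implicit seed of 1.0 for a scalar output. As far as I know, in recent LibTorch releases the signature is backward(const Tensor& gradient = {}, c10::optional<bool> retain_graph = c10::nullopt, bool create_graph = false, ...), so the call above would be written loss.backward({}, /*retain_graph=*/true);.

```cpp
#include <cassert>
#include <optional>
#include <stdexcept>

// Hypothetical toy node for y = x * x. This is NOT the LibTorch API; it only
// mimics how autograd keeps a buffer for backward and frees it afterwards.
struct ToySquare {
    double x;                       // leaf value
    double grad_x = 0.0;            // accumulated gradient of y w.r.t. x
    std::optional<double> saved_x;  // buffer saved for backward; empty once freed

    explicit ToySquare(double input) : x(input), saved_x(input) {}

    double forward() const { return x * x; }

    // Same parameter shape as the old Tensor::backward(gradient, keep_graph,
    // create_graph). create_graph is accepted for symmetry and ignored here.
    void backward(std::optional<double> gradient, bool keep_graph,
                  bool /*create_graph*/) {
        if (!saved_x) {
            throw std::runtime_error(
                "Trying to backward through the graph a second time");
        }
        // No explicit gradient: seed the scalar output with 1.0.
        const double seed = gradient.value_or(1.0);
        grad_x += seed * 2.0 * *saved_x;  // dy/dx = 2x, accumulated like .grad
        if (!keep_graph) {
            saved_x.reset();  // default behaviour: free the saved buffer
        }
    }
};

// Usage: two backward passes on one graph only work if the first keeps it.
inline void demo_keep_graph() {
    ToySquare node(3.0);
    assert(node.forward() == 9.0);
    node.backward(std::nullopt, /*keep_graph=*/true, /*create_graph=*/false);
    node.backward(std::nullopt, /*keep_graph=*/false, /*create_graph=*/false);
    assert(node.grad_x == 12.0);  // 6.0 from each pass

    bool threw = false;
    try {
        node.backward(std::nullopt, false, false);  // graph already freed
    } catch (const std::runtime_error&) {
        threw = true;
    }
    assert(threw);
}
```

Freeing saved buffers by default keeps memory use low; keep_graph = true is the opt-in for traversing the same graph again. That is presumably why it came up here: desired_out is computed once outside the loop, so its part of the graph would be reused in every iteration.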

In your example, though, a better solution would be to detach_() desired_out.

Best regards


(Martin Huber) #4

Hi Thomas, and thanks for the reply. That is indeed what I had already tried, and it did not work on my system. I have now tried the latest binaries, so it looks to me like this was a bug.

The detach_() approach works! So for this simple example it is:

torch::Tensor desired_out = linear2->forward(in).detach();
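Why detach() helps can be sketched with another toy in plain C++. ToyValue, its detach() member and toy_mse_loss are hypothetical stand-ins, not LibTorch functions: the loss refuses a target that requires grad, roughly mirroring the mse_loss behaviour reported above, while detach() returns the same value cut off from the graph, so the target becomes a constant.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical stand-in for a tensor with an autograd flag; NOT the LibTorch API.
struct ToyValue {
    double data;
    bool requires_grad;

    // Like Tensor::detach(): same data, no longer tracked by autograd.
    ToyValue detach() const { return ToyValue{data, false}; }
};

// Toy squared-error loss with no derivative for its target, so (like the
// mse_loss in the thread) it rejects a target that requires grad.
inline double toy_mse_loss(const ToyValue& out, const ToyValue& target) {
    if (target.requires_grad) {
        throw std::runtime_error("the derivative for 'target' is not implemented");
    }
    const double diff = out.data - target.data;
    return diff * diff;
}

// Usage: a target produced by another layer fails until it is detached.
inline void demo_detach() {
    const ToyValue out{0.5, true};     // e.g. linear1's output
    const ToyValue target{2.0, true};  // e.g. linear2's output, still in the graph

    bool threw = false;
    try {
        (void)toy_mse_loss(out, target);
    } catch (const std::runtime_error&) {
        threw = true;
    }
    assert(threw);

    assert(toy_mse_loss(out, target.detach()) == 2.25);  // (0.5 - 2.0)^2
}
```

In LibTorch itself, detach() and detach_() are the usual out-of-place versus in-place pair: the first returns a new tensor, the second modifies desired_out itself.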

(Thomas V) #5

Yeah, the C++ API is still in flux (I think there are nightly builds of it these days, though).
Glad this works for you!

Best regards