Checkpointing can reduce memory usage in PyTorch. However, libtorch has no such function yet. Here I am trying to implement one with torch::autograd::backward. My code compiles successfully, but when I run it I hit a runtime problem. The code is attached below.
#include <iostream>

#include <torch/torch.h>
using namespace torch::autograd;
class TestFunction : public Function
{
public:
inline torch::Tensor Testf(torch::Tensor x, torch::Tensor y, const double &coeff)
{
torch::Tensor z = torch::zeros(x.sizes());
z = (3 * x + y) * coeff;
return z;
}
static variable_list forward(AutogradContext *ctx, Variable x, Variable y, const double &coeff)
{
std::cout << " forward" << std::endl;
ctx->saved_data["coeff"] = coeff;
ctx->save_for_backward({x, y});
torch::NoGradGuard no_guard;
TestFunction ts;
auto z = ts.Testf(x, y, coeff);
return {z};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_output)
{
std::cout << " backward" << std::endl;
double coeff = ctx->saved_data["coeff"].toDouble();
auto v = ctx->get_saved_variables();
Variable x = v[0];
Variable y = v[1];
std::cout << grad_output[0] << std::endl;
Variable x1 = x.detach();
Variable y1 = y.detach();
x1.set_requires_grad(x.requires_grad());
y1.set_requires_grad(y.requires_grad());
torch::Tensor z;
TestFunction ts;
z = ts.Testf(x1, y1, coeff);
torch::AutoGradMode enable_grad(true);
torch::autograd::backward({z}, grad_output);
return {x1.grad(), y1.grad(), torch::Tensor()};
}
};
std::vectorat::Tensor TT(torch::Tensor x, torch::Tensor y, const double &coeff)
{
return TestFunction::apply(x, y, coeff);
}
auto main() → int
{
torch::manual_seed(3);
Variable x = Variable(torch::rand({2, 2})).set_requires_grad(true);
std::cout << "x=" << x << std::endl;
Variable y = Variable(torch::rand({2, 2})).set_requires_grad(true);
std::cout << "y=" << y << std::endl;
Variable v = torch::tensordot(x, y, {1}, {0});
Variable u = torch::exp(y);
std::vector<at::Tensor> z = TT(u, v, 5.0);
z[0].sum().backward();
std::cout << "x.grad=" << x.grad() << std::endl;
std::cout << "y.grad=" << y.grad() << std::endl;
return 0;
}