I ran into problems writing a C++ version of an operator such as Concat, which takes several inputs and produces a single output. Because Concat consumes a variable number of inputs, I use `torch::autograd::variable_list`
as its input type. Here is a simplified version of my operator:
/// Custom autograd function computing the elementwise product z = x * y of
/// two tensors that are passed together in a single variable_list.
struct MyOperator : public torch::autograd::Function<MyOperator>
{
    /// Forward pass.
    /// @param ctx   autograd context; x and y are saved on it for backward.
    /// @param input list of input tensors; must contain exactly two: {x, y}.
    /// @return a one-element list holding x * y.
    static torch::autograd::variable_list forward(torch::autograd::AutogradContext *ctx, torch::autograd::variable_list input)
    {
        // Guard the indexing below: reading input[0]/input[1] on a shorter list is UB.
        TORCH_CHECK(input.size() == 2, "MyOperator expects exactly 2 inputs, got ", input.size());
        const auto& x = input[0];
        const auto& y = input[1];
        ctx->save_for_backward({x, y});
        // In the author's build, this output's requires_grad() returns true.
        return {x * y};
    }

    /// Backward pass.
    /// @param ctx         autograd context holding the tensors saved in forward.
    /// @param grad_output gradients w.r.t. forward's outputs; element 0 is dL/dz.
    /// @return gradients w.r.t. {x, y}, in the same order as forward's inputs.
    static torch::autograd::variable_list backward(torch::autograd::AutogradContext *ctx, torch::autograd::variable_list grad_output)
    {
        const auto saved = ctx->get_saved_variables();
        const auto& x = saved[0];
        const auto& y = saved[1];
        const auto& dz = grad_output[0];
        // d(x*y)/dx = y and d(x*y)/dy = x, each scaled by the upstream gradient.
        return {y * dz, x * dz};
    }
};  // The terminating semicolon is required after a struct definition.
The following code throws at `loss.backward()` because `loss.requires_grad()` returns false:
torch::autograd::Variable x_ = torch::randn({5,5}, torch::requires_grad());
torch::autograd::Variable y_ = torch::randn({5,5}, torch::requires_grad());
torch::autograd::variable_list args_;
args_.push_back(x_);
args_.push_back(y_);
auto res = MyEngine::apply(args_);
auto go = torch::ones({}, torch::requires_grad());
auto loss = res.at(0).sum();
// This line throws!
loss.backward();
Did I do something wrong?