Is retain_grad() supported in the new C++ API?

I want to use retain_grad() to get the gradient of non-leaf variables in C++. Hook registration is implemented in …, but I didn't find a retain_grad() implementation in C++. Does that mean I need to implement it myself with registerFunctionHook()? Thanks!

Yes, but as you said, you would need to manually create a FunctionPreHook (defined here) and then use variable.add_hook (see here) to register it.

Thanks Simon, will do that! Just curious: do you have plans for adding an explicit retain_grad() in C++? I imagine it could be a very useful helper hook. LMK if I can be of any help here.

I implemented a FunctionPreHook and used variable.add_hook to "register" it, but at runtime the hook is not invoked at all. I checked the length of pre_hooks_ for the Functions at runtime, and it is always 0.

This is where I checked the length of pre_hooks_, inside engine.cpp:

static variable_list call_pre_hooks(Function& fn, variable_list inputs) {
  // Debug print: how many pre-hooks the engine finds on this Function.
  std::cout << fn.pre_hooks().size() << std::endl;
  for (const auto& hook : fn.pre_hooks()) {
    inputs = (*hook)(inputs);
  }
  return inputs;
}

I checked the code under csrc/autograd but didn't find where add_pre_hook() is called in C++, or where the hooks stored on a Variable are added to a Function's pre_hooks_.

Did I miss something? Thanks!

Here is a minimal example, in which "here" is expected to be printed twice.

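#include <torch/torch.h>
#include <torch/csrc/autograd/function_hook.h>
#include <iostream>
#include <memory>

using namespace torch::autograd;

// Pre-hook that only logs when it runs and passes the gradients through unchanged.
struct DummyHook : public FunctionPreHook {
  explicit DummyHook(Variable* x) : v(x) {}
  variable_list operator()(const variable_list& grads) override {
    std::cout << "here" << std::endl;
    return grads;
  }
  Variable* v;
};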

int main(int argc, char** argv) {
  torch::Tensor x = torch::tensor({1.0}, at::requires_grad()); // CPUDoubleType
  Variable y = x * x;
  // Registers the hook on the Variable itself; for a non-leaf like y this hook was never invoked.
  y.add_hook(std::make_shared<DummyHook>(&y));
  torch::Tensor z = 2 * y;
  for (auto hook: y.hooks()) {

I think add_hook only works on leaves. For intermediate variables, you can try variable.grad_fn()->add_pre_hook instead.


add_pre_hook on grad_fn works, thanks Simon!
