[torch::autograd::Function] Any way to know whether forward is called with autograd on?

Assume I have a working torch::autograd::Function whose forward method can be computed in two ways, call them forward_1 and forward_2. For inference only, forward_1 is faster; if I intend to backprop, forward_2 + backward is faster than forward_1 + backward. The reason is that forward_2 precomputes some quantities in anticipation of libtorch later calling backward, which can then build on those quantities to compute the needed gradients efficiently.

Ideally, I’d like a way to tell, inside torch::autograd::Function::forward, whether the method is being called with autograd on or off (i.e. with or without a no_grad guard), so that I could do something like:

if(autograd_is_on){
  forward_2(...);
  ctx->saved_data["..."] = ...; // save the pre-computed stuff that backward will need
} else {
  forward_1(...);
}

I think torch::GradMode::is_enabled() should work as an internal check.


Thank you! I will try this.

It doesn’t seem to work :-(. During training I get a runtime error in backward: torch complains that it expects a Tensor but finds None in ctx->saved_data. When I replace the pseudo-code above with the following, the code runs fine:

forward_2(...);
ctx->saved_data["..."] = ...; // save the pre-computed stuff that backward will need

(i.e. I no longer check whether autograd is on and always compute what backward needs)

It looks like GradMode::is_enabled() doesn’t return the correct value inside a torch::autograd::Function?

I don’t fully understand the issue: using the condition raises the runtime error when accessing ctx->saved_data, but the code works without the condition?

Assume the code looks like the following:

torch::autograd::variable_list SomeFunc::forward(torch::autograd::AutogradContext *ctx, torch::Tensor x) {
  /* ... */
  auto stream = c10::cuda::getCurrentCUDAStream();
  if(torch::GradMode::is_enabled()){
    // This will perform the forward computation in a different way using 
    // another slightly slower (in terms of inference only) formula having 
    // subexpressions that are in common with the backward computation, 
    // which makes the forward+backward fast.
    my_cuda_kernels::forward_and_precompute_for_backward(x, out, quantities_for_backward, stream);
    ctx->saved_data["quantities_for_backward"] = quantities_for_backward;
  } else {
    // The following will take a much faster path,
    // assumes we only need to do inference.
    my_cuda_kernels::forward_only(x, out, stream);
  }
  return {out};
}

torch::autograd::variable_list SomeFunc::backward(torch::autograd::AutogradContext *ctx, torch::autograd::variable_list grad_output) {
  auto quantities_for_backward = ctx->saved_data["quantities_for_backward"].toTensor();
  /* do the backward computation by using "quantities_for_backward" to avoid recomputing
     some subexpressions that are common with the forward computation */
  /* ... */
  // backward must return one gradient per input of forward();
  // grad_x is a hypothetical name for the gradient w.r.t. x computed above.
  return {grad_x};
}

The problem arises when I use this Function with autograd enabled (e.g. SomeFunc::apply(some_input)[0].mean().backward(), where some_input has requires_grad set to true): at runtime libtorch complains about getting None instead of a Tensor at ctx->saved_data["quantities_for_backward"].toTensor(). So I assume that torch::GradMode::is_enabled() is not returning true during the forward pass.

When I remove the if statement from forward and always precompute the backward quantities and save them in ctx, the code works again, which further suggests that torch::GradMode::is_enabled() does not return what I expect inside a torch::autograd::Function.
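The None in the error message is consistent with how saved_data behaves, as far as I know: it is a hash map from std::string to c10::IValue whose operator[] default-constructs a missing entry, and a default-constructed IValue is None. So when forward skips the branch that stores the key, backward silently creates a None entry and only toTensor() fails. The plain C++ model below (no libtorch; every name is hypothetical, and std::optional stands in for IValue) reproduces that sequence:

```cpp
#include <cassert>  // for assert() in the usage checks
#include <optional>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Hypothetical model of ctx->saved_data: an empty std::optional plays the
// role of a default-constructed (None) IValue.
using Value = std::optional<double>;
using SavedData = std::unordered_map<std::string, Value>;

// Stand-in for IValue::toTensor(): fails when the stored value is "None".
double to_payload(const Value& v) {
  if (!v.has_value()) {
    throw std::runtime_error("Expected a payload but got None");
  }
  return *v;
}

// Models forward: only the precompute branch stores anything.
void forward_model(SavedData& saved, bool precompute_branch_taken) {
  if (precompute_branch_taken) {
    saved["quantities_for_backward"] = 42.0;  // placeholder value
  }
}

// Models backward: operator[] inserts an empty Value when the key is
// missing, so the failure only appears at conversion time, mirroring
// the toTensor() error.
bool backward_fails(SavedData& saved) {
  try {
    to_payload(saved["quantities_for_backward"]);
    return false;
  } catch (const std::runtime_error&) {
    return true;
  }
}
```

In the real backward, looking the key up without inserting (for example with the map’s find()) would turn this into an explicit missing-key error instead of a type mismatch.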

Your issue makes sense, since gradient calculation is disabled inside custom autograd Functions by default.
You can check this in your forward by printing torch::GradMode::is_enabled() (print(torch.is_grad_enabled()) in Python): it returns false, so you would have to enable it manually.
Sorry for missing this, but I thought the question was only asking how to check whether gradient calculation is enabled.
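As far as I can tell from libtorch’s custom_function.h in recent versions, Function<T>::apply decides whether to attach a backward node using the caller’s grad mode (and whether any input requires grad), and only then calls T::forward inside an AutoGradMode(false) guard. That ordering is why GradMode::is_enabled() reads false inside forward regardless of the caller. The plain C++ model below (no libtorch; namespace model and everything in it is hypothetical) reproduces the ordering so the effect can be checked in isolation:

```cpp
#include <cassert>  // for assert() in the usage checks

// Hypothetical model of libtorch's thread-local grad mode and of the
// guard that apply() places around forward. Not the real API.
namespace model {

thread_local bool grad_enabled = true;

bool is_grad_enabled() { return grad_enabled; }

// RAII guard in the spirit of torch::AutoGradMode / torch::NoGradGuard:
// sets the mode on construction and restores the previous one on exit.
class AutoGradMode {
 public:
  explicit AutoGradMode(bool enabled) : prev_(grad_enabled) {
    grad_enabled = enabled;
  }
  ~AutoGradMode() { grad_enabled = prev_; }
  AutoGradMode(const AutoGradMode&) = delete;
  AutoGradMode& operator=(const AutoGradMode&) = delete;

 private:
  bool prev_;
};

struct ApplyResult {
  bool graph_recorded;        // would a backward node be attached?
  bool grad_seen_in_forward;  // what forward observes via is_grad_enabled()
};

// Model of apply(): the graph decision uses the caller's grad mode,
// then forward runs under a guard that disables grad mode.
ApplyResult apply(bool input_requires_grad) {
  ApplyResult result{};
  result.graph_recorded = is_grad_enabled() && input_requires_grad;
  {
    AutoGradMode forward_guard(false);
    result.grad_seen_in_forward = is_grad_enabled();  // always false here
  }
  return result;
}

}  // namespace model
```

With grad mode on, the model records a graph while forward still observes false; under a no-grad guard neither happens, so forward cannot distinguish the two cases by itself.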


So I guess there’s no libtorch-“standard” way to know, inside the forward method, whether the autograd::Function is being called from a context with autograd on? (Except, of course, adding my own flag that I manually set to true in training/backprop mode and false otherwise.)
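One workaround that stays within the documented custom-Function API is to compute the flag in the caller, before apply() turns grad mode off, and pass it to forward as an extra non-tensor argument (apply() forwards non-tensor arguments, and backward then returns an undefined tensor for them). In libtorch the wrapper would look roughly like `bool need_backward = torch::GradMode::is_enabled() && x.requires_grad(); auto out = SomeFunc::apply(x, need_backward)[0];`, assuming forward gains a hypothetical `bool need_backward` parameter. The plain C++ model below (no libtorch; all names hypothetical) sketches only the dispatch logic:

```cpp
#include <cassert>  // for assert() in the usage checks
#include <string>
#include <unordered_map>

// Hypothetical model: the forward path is chosen from a flag computed by
// the caller, not from the grad mode seen inside forward.
namespace model {

thread_local bool grad_enabled = true;  // stand-in for torch::GradMode

struct Context {
  std::unordered_map<std::string, double> saved_data;  // stand-in for ctx->saved_data
};

enum class Path { ForwardOnly, ForwardAndPrecompute };

// Stand-in for SomeFunc::forward(ctx, x, need_backward).
Path forward(Context& ctx, double x, bool need_backward) {
  if (need_backward) {
    ctx.saved_data["quantities_for_backward"] = 2.0 * x;  // placeholder precompute
    return Path::ForwardAndPrecompute;
  }
  return Path::ForwardOnly;
}

// Stand-in for the wrapper around SomeFunc::apply: reads the caller's grad
// mode first, then runs forward with grad mode off, as apply() does.
Path some_func(Context& ctx, double x, bool x_requires_grad) {
  const bool need_backward = grad_enabled && x_requires_grad;
  const bool saved_mode = grad_enabled;
  grad_enabled = false;
  const Path path = forward(ctx, x, need_backward);
  grad_enabled = saved_mode;
  return path;
}

}  // namespace model
```

In the real code, backward would then return an undefined tensor for the bool input, e.g. `{grad_x, torch::Tensor()}`, where grad_x is again a hypothetical name.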

Okay, I think I could encapsulate it in a module and use its is_training() method to decide which of the two forward versions to use. This means calling .train() and .eval() manually, but at least it’s more “standard” than introducing my own flags, probably similar to what BatchNorm does.
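A module-based version of that idea could look like the plain C++ model below (no libtorch; SomeModuleModel is hypothetical). In libtorch, torch::nn::Module already provides train(), eval() and is_training(), and modules start in training mode; the model mimics only that flag. One consequence under this sketch’s assumptions: calling backward through a module left in eval mode would hit the same missing-saved_data error as above, so the flag has to match how the module is actually used.

```cpp
#include <cassert>  // for assert() in the usage checks

// Hypothetical model of a module wrapping SomeFunc; a libtorch version
// would derive from torch::nn::Module and use its training flag.
class SomeModuleModel {
 public:
  void train(bool on = true) { training_ = on; }
  void eval() { train(false); }
  bool is_training() const { return training_; }

  // True when the precompute-for-backward path would be taken. In libtorch
  // this flag could be passed on, e.g. SomeFunc::apply(x, is_training()),
  // assuming forward takes an extra bool parameter.
  bool uses_precompute_path() const { return is_training(); }

 private:
  bool training_ = true;  // modules start in training mode
};
```

The design mirrors BatchNorm in that the mode is an explicit module state rather than something inferred from the call context.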