Is it safe to set a flag for a single autograd.Function.backward call?

I have defined a custom autograd.Function with forward() and backward() overridden.

In each training step, my use case requires two autograd backward calls: one with retain_graph=True, and a second one for the actual loss backward.

The problem is that for the first backward() call I need to set a temporary flag to trigger some special behavior. Currently I have implemented this by setting a flag on a Python object (native_module) that is stored on ctx during forward(), and it seems to work fine. I want to know whether this is guaranteed to be safe.

The generic pseudo code looks like this:

import torch

class fun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, native_module, input):
        ctx.native_module = native_module
        return ctx.native_module.fwd(input)
    @staticmethod
    def backward(ctx, doutput):
        # Check the flag on the native module at the time backward runs
        some_flag = ctx.native_module.some_flag
        grad = ctx.native_module.bwd(doutput, some_flag)
        # backward() must return one gradient per forward() input:
        # None for native_module (not a tensor), grad for input
        return None, grad

class NativeModule:
    some_flag: bool = False
    def fwd(self, input):
        # Some behavior
        return ...
    def bwd(self, doutput, some_flag: bool):
        if some_flag:
            # Some behavior
            return ...
        else:
            # Some other behavior
            return ...

# Train step
module = NativeModule()
for _ in range(100):
    input = ...
    output = fun.apply(module, input)
    
    # First backward call (some_flag passed to module should be True)
    module.some_flag = True
    grad_in = torch.autograd.grad(output, input, ..., retain_graph=True)[0]
    module.some_flag = False
    loss = some_func(grad_in.data, output)
    
    # Second backward call (some_flag passed to module should be False)
    loss.backward()
    
    # Some other procedure

Sounds OK, but why does your NativeModule.bwd function take some_flag as an argument when it can already access some_flag as an attribute of the module?

Thanks for your reply! Oh, that comes from an oversimplification of my actual use case. Sorry!

What I am most uncertain about is this: when we call torch.autograd.grad() (or loss.backward()), is the static backward() of the autograd.Function invoked immediately and synchronously? Or is there any asynchronous or deferred mechanism behind the backward() call?

Ah I see, no worries!
Yes, that is correct: the Function's backward() is executed synchronously inside the torch.autograd.grad() / loss.backward() call, and the call only returns once the backward pass has finished. So I don't expect there to be any weird interaction with the global flag here.
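
If it helps, here is a minimal, self-contained sketch (using a toy SyncCheck function I made up, not your actual module) that you can run to convince yourself that the custom backward() executes synchronously inside the torch.autograd.grad() call:

import torch

class SyncCheck(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Simple y = 2 * x so the gradient is easy to verify
        return x * 2
    @staticmethod
    def backward(ctx, grad_out):
        # Record that backward has run so the caller can check when it happened
        SyncCheck.backward_ran = True
        return grad_out * 2

SyncCheck.backward_ran = False
x = torch.randn(3, requires_grad=True)
y = SyncCheck.apply(x).sum()

assert SyncCheck.backward_ran is False
grad_x, = torch.autograd.grad(y, x, retain_graph=True)
# By the time grad() returns, the custom backward() has already executed
assert SyncCheck.backward_ran is True
assert torch.allclose(grad_x, torch.full_like(x, 2.0))

The engine may execute backward nodes on its internal worker threads, but the grad()/backward() call itself blocks until the whole traversal, including your Function's backward(), has finished, so your flip-then-reset pattern will see the flag value you expect. The one thing I would add is resetting the flag in a try/finally block, so it does not stay stuck at True if the first backward raises an exception.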