I have defined a custom autograd.Function with forward() and backward() overridden.
In each training step, my use case requires two backward passes through autograd: one with retain_graph=True to compute an intermediate gradient, and a second one for the actual loss backward.
The problem is that, for the first backward call, I need to set a temporary flag so that it takes a special code path. Currently I implement this by setting a flag on a Python object called native_module, which is stored into ctx during forward(), and it seems to work fine. I want to know whether this is guaranteed to be safe.
The generic pseudocode looks like this:
import torch

class fun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, native_module, input):
        # Keep a reference to the native module so backward() can read its flag later
        ctx.native_module = native_module
        return ctx.native_module.fwd(input)

    @staticmethod
    def backward(ctx, doutput):
        # Check the flag in the native module at the time backward actually runs
        some_flag = ctx.native_module.some_flag
        grad = ctx.native_module.bwd(doutput, some_flag)
        # No gradient for the native_module argument, only for input
        return None, grad
class NativeModule:
    some_flag: bool = False

    def fwd(self, input):
        # Some behavior
        return ...

    def bwd(self, doutput, some_flag: bool):
        if some_flag:
            # Some behavior
            return ...
        else:
            # Some other behavior
            return ...
# Train step
module = NativeModule()
for _ in range(100):
    input = ...
    output = fun.apply(module, input)
    # First backward call (some_flag passed to module should be True)
    module.some_flag = True
    grad_in = torch.autograd.grad(output, input, ..., retain_graph=True)[0]
    module.some_flag = False
    loss = some_func(grad_in.data, output)
    # Second backward call (some_flag passed to module should be False)
    loss.backward()
    # Some other procedure
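
For context, a minimal runnable variant of the same training step, reusing the fun and NativeModule classes above, would look roughly like this; the scale-by-two operation, the tensor shapes, and the stand-in for some_func are just placeholders, not my real code:

class ToyNativeModule(NativeModule):
    def fwd(self, input):
        return input * 2.0

    def bwd(self, doutput, some_flag: bool):
        # Placeholder special behavior: negate the gradient when the flag is set
        return -2.0 * doutput if some_flag else 2.0 * doutput

module = ToyNativeModule()
input = torch.randn(4, requires_grad=True)
output = fun.apply(module, input)

# First backward call: the module should see some_flag == True
module.some_flag = True
grad_in = torch.autograd.grad(output, input, torch.ones_like(output), retain_graph=True)[0]
module.some_flag = False

# Second backward call: the module should see some_flag == False
loss = (grad_in.detach() * output).sum()  # stand-in for some_func
loss.backward()
print(grad_in)      # gradient computed while some_flag was True
print(input.grad)   # gradient from loss.backward(), computed while some_flag was False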