It’s possible to define a custom autograd function as shown here:
from torch.autograd import Function

class Custom(Function):
    @staticmethod
    def forward(ctx, input):
        output = ...  # compute output
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad = ...  # compute grad
        return grad
What should I do if my function needs to take a bool flag as an extra argument? That is, something like:
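class Custom(Function):
    @staticmethod
    def forward(ctx, input, flag):
        if flag:
            output = ...  # compute output (flag is True)
        else:
            output = ...  # compute output (flag is False)
        ctx.flag = flag  # keep the non-tensor flag for backward
        return output

    @staticmethod
    def backward(ctx, grad_output):
        flag = ctx.flag
        if flag:
            grad = ...  # compute grad (flag is True)
        else:
            grad = ...  # compute grad (flag is False)
        # backward returns one value per forward input; flag gets None
        return grad, None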
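For what it's worth, here is a minimal runnable sketch of the flag pattern. The class name `SquareOrCube` and the square/cube math are made up purely for illustration. It assumes the flag is a plain Python bool: it is passed as an extra argument to `forward`, stored as an attribute on `ctx`, and `backward` returns `None` in its position because a bool has no gradient.

```python
import torch
from torch.autograd import Function


class SquareOrCube(Function):  # hypothetical example, not from the question
    @staticmethod
    def forward(ctx, input, flag):
        # Tensors needed in backward go through save_for_backward;
        # a plain Python value such as a bool can be stored on ctx directly.
        ctx.save_for_backward(input)
        ctx.flag = flag
        return input ** 2 if flag else input ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        if ctx.flag:
            grad = 2 * input * grad_output       # d/dx x^2
        else:
            grad = 3 * input ** 2 * grad_output  # d/dx x^3
        # One return value per forward argument: grad for input, None for flag.
        return grad, None


# Custom Functions are called through .apply, not instantiated.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
SquareOrCube.apply(x, True).sum().backward()
grad_square = x.grad.clone()

x.grad = None  # reset before the second backward pass
SquareOrCube.apply(x, False).sum().backward()
grad_cube = x.grad.clone()
```

Note that the flag is passed positionally to `apply`, in the same order as in the `forward` signature.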