Custom autograd function with extra flags

It’s possible to define a custom autograd function as shown here:

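import torch
from torch.autograd import Function

class Custom(Function):

    @staticmethod
    def forward(ctx, input):
        # Compute output (identity used here as a placeholder)
        output = input.clone()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Compute grad (gradient of the identity placeholder)
        grad = grad_output
        return grad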
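As a usage sketch, a custom Function is not instantiated; it is called through its `apply` classmethod. The class name `Square` and the squaring computation below are hypothetical, chosen only so the gradient is easy to check by hand. Tensors needed in backward go through `ctx.save_for_backward` and come back via `ctx.saved_tensors`.

```python
import torch
from torch.autograd import Function

class Square(Function):
    """Hypothetical concrete example: y = x ** 2."""

    @staticmethod
    def forward(ctx, input):
        # Tensors needed for the gradient are saved with save_for_backward
        ctx.save_for_backward(input)
        return input * input

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # d(x^2)/dx = 2x, chained with the incoming gradient
        return grad_output * 2 * input

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = Square.apply(x)  # call via .apply, not Square()(x)
y.sum().backward()
print(x.grad)  # tensor([2., 4., 6.])
```

`torch.autograd.gradcheck` (with double-precision inputs) is a convenient way to confirm a hand-written backward matches finite differences.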

What should I do if I need my function to have a bool flag? That is,

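import torch
from torch.autograd import Function

class Custom(Function):

    @staticmethod
    def forward(ctx, input, flag):
        # flag is a plain Python bool, passed as an extra forward argument
        if flag:
            output = input * 2  # placeholder: compute output for flag=True
        else:
            output = input * 3  # placeholder: compute output for flag=False
        ctx.flag = flag  # stash the flag for backward
        return output

    @staticmethod
    def backward(ctx, grad_output):
        flag = ctx.flag
        if flag:
            grad = grad_output * 2  # placeholder: grad for flag=True
        else:
            grad = grad_output * 3  # placeholder: grad for flag=False
        # one return value per forward input; the bool flag gets None
        return grad, None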

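Just save the flag on ctx in forward (ctx.flag = flag) and read it back in backward. Because flag is an argument to forward, backward must return one value per forward input, so return None as the "gradient" for the bool: return grad, None.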
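A minimal runnable sketch of this pattern, assuming a hypothetical computation where the flag switches between squaring (flag=True) and identity (flag=False). Non-tensor values such as the bool are stored directly as attributes on ctx, while tensors go through `save_for_backward`. The flag is passed positionally to `apply`.

```python
import torch
from torch.autograd import Function

class Custom(Function):

    @staticmethod
    def forward(ctx, input, flag):
        ctx.flag = flag  # plain Python bool, stored directly on ctx
        if flag:
            ctx.save_for_backward(input)  # tensor needed for the square's gradient
            return input * input          # hypothetical: square when flag is True
        return input.clone()              # hypothetical: identity when flag is False

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.flag:
            (input,) = ctx.saved_tensors
            grad = grad_output * 2 * input
        else:
            grad = grad_output
        return grad, None  # None: no gradient for the non-tensor flag

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

Custom.apply(x, True).sum().backward()
print(x.grad)  # tensor([2., 4., 6.])

x.grad = None  # reset the accumulated gradient before the second call
Custom.apply(x, False).sum().backward()
print(x.grad)  # tensor([1., 1., 1.])
```

Returning `grad` alone (without the trailing `None`) would make autograd complain that backward returned fewer gradients than forward had inputs.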
