Custom autograd.Function: must it be static?

Hi,

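Sure! Here is the static version. `lambd` becomes an extra argument to `forward` and is stored on `ctx`, and `backward` returns one gradient per `forward` input (`None` for `lambd`, since it is not a tensor):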

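from torch.autograd import Function

class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, lambd):
        # Non-tensor arguments can be stored directly on ctx for use in backward
        ctx.lambd = lambd
        # Identity in the forward pass; view_as returns a new tensor object
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: reversed and scaled for x, None for lambd
        return (grad_output * -ctx.lambd), None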
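
A quick sanity check, assuming a recent PyTorch: the forward pass is the identity, while the gradient reaching `x` is flipped and scaled by `lambd`. The input values and `lambd = 0.5` are arbitrary choices for illustration.

```python
import torch
from torch.autograd import Function


class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd          # plain Python float, kept for backward
        return x.view_as(x)        # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient for x is negated and scaled; lambd gets no gradient
        return grad_output * -ctx.lambd, None


x = torch.ones(3, requires_grad=True)
y = GradReverse.apply(x, 0.5)
y.sum().backward()

print(y.detach())  # same values as x
print(x.grad)      # every entry is -0.5 instead of the usual 1.0
```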

Then replace the old-style call, which built an instance configured with `lambd`:

GradReverse(lambd)(inp)

with a call to the static `apply` method, passing `lambd` as an extra argument:

GradReverse.apply(inp, lambd)
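
If you liked holding a configured object the way `GradReverse(lambd)` used to work (for example, to drop it into `nn.Sequential`), one option is a thin `nn.Module` wrapper that stores `lambd` and calls `apply` in its own `forward`. This is a minimal sketch; `GradReverseLayer`, `feature`, and `head` are hypothetical names, not part of PyTorch.

```python
import torch
from torch import nn
from torch.autograd import Function


class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * -ctx.lambd, None


class GradReverseLayer(nn.Module):
    """Hypothetical wrapper: keeps lambd on the module, like the old instance did."""

    def __init__(self, lambd=1.0):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        # The static Function is still invoked through apply
        return GradReverse.apply(x, self.lambd)


torch.manual_seed(0)
feature = nn.Linear(4, 2)                                   # layer before the reversal
head = nn.Sequential(GradReverseLayer(lambd=2.0), nn.Linear(2, 1))

inp = torch.randn(5, 4)
loss = head(feature(inp)).sum()
loss.backward()
```

With this layout, parameters before the wrapper (here `feature`) receive gradients multiplied by `-lambd`, while layers after it are trained normally, which is what a gradient reversal layer is meant to do.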