Hi,
Sure! Here it is as a new-style autograd Function with static methods:
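from torch.autograd import Function

class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, lambd):
        # Stash the scale factor on ctx so backward can read it
        ctx.lambd = lambd
        # Identity in the forward pass (view_as returns a new tensor object)
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient; lambd needs no gradient, so return None for it
        return grad_output * -ctx.lambd, None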
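Then, at the call site, replace the old-style instantiation:
    GradReverse(lambd)(inp)
with a call to the static apply method, passing lambd as an extra argument:
    GradReverse.apply(inp, lambd)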
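Here is a minimal sketch for checking the new-style call end to end, assuming a recent PyTorch version. The forward pass should leave the values unchanged, and the gradient that reaches x should be -lambd times the upstream gradient. The grad_reverse helper is a hypothetical convenience wrapper, not part of PyTorch.

```python
import torch
from torch.autograd import Function


class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, lambd):
        # Stash the scale factor on ctx so backward can read it
        ctx.lambd = lambd
        # Identity in the forward pass
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient; no gradient for lambd
        return grad_output * -ctx.lambd, None


def grad_reverse(x, lambd=1.0):
    # Hypothetical helper: thin wrapper around the new-style call
    return GradReverse.apply(x, lambd)


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = grad_reverse(x, 0.5)

# Forward pass is the identity
print(torch.equal(y, x))

# d(sum)/dx is 1 for every element, so the reversed gradient is -0.5
y.sum().backward()
print(x.grad)
```

Putting the call in a small function like this keeps the call site readable, since apply only accepts positional arguments.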