Pass parameter in forward pass

My helper classes look like this:

import torch
from torch.autograd import Function


class GradientReversalFunction(Function):
    """
    Gradient Reversal Layer from:
    Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)

    The forward pass is the identity function. In the backward pass,
    the upstream gradients are multiplied by -lambda (i.e. the gradient is reversed).
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        # Stash lambda_ on the context so backward() can read it
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        lambda_ = ctx.lambda_
        # Put lambda_ on the same device/dtype as the incoming gradients
        lambda_ = grads.new_tensor(lambda_)
        dx = -lambda_ * grads
        # One gradient per forward() input: x gets dx, lambda_ gets None
        return dx, None


class GradientReversal(torch.nn.Module):
    def __init__(self, lambda_=1):
        super(GradientReversal, self).__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        # Hand the stored lambda_ on to the autograd Function
        return GradientReversalFunction.apply(x, self.lambda_)

and in my model I use

        self.discriminator = nn.Sequential(
            GradientReversal(),
            nn.Linear(40, 20),
            nn.ReLU(),
            nn.Linear(20, 10),
            nn.ReLU(),
            nn.Linear(10, 1),
        )

and in the forward pass I call

self.discriminator(x).squeeze()
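
As a quick sanity check that the layer behaves as advertised (a minimal sketch, assuming the classes above with the fixed-lambda_ forward), the reversal shows up directly in the gradients:

import torch

grl = GradientReversal(lambda_=0.5)
x = torch.ones(3, requires_grad=True)
grl(x).sum().backward()
print(x.grad)  # tensor([-0.5000, -0.5000, -0.5000]): upstream ones scaled by -lambda_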

Now, instead of a fixed lambda_=1, I would like to call my forward function with varying values of lambda_.

I have tried updating the forward call of GradientReversal accordingly, but I only get errors.

Our Lord and Saviour ChatGPT actually pointed it out to me.

It does not suffice to update the forward method of GradientReversal to

def forward(self, x, lambda_):
    return GradientReversalFunction.apply(x, lambda_)
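
With only that change, the original call still fails: nn.Sequential invokes each submodule with a single input, so it has no way to thread the extra lambda_ through, and Python raises a TypeError roughly like the one below (the exact wording depends on the Python version):

self.discriminator(x).squeeze()
# TypeError: forward() missing 1 required positional argument: 'lambda_'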

Instead, I also have to explicitly grab the first entry of my discriminator, call it with the extra argument, and feed the result through the remaining layers:

x_grad_rev = self.discriminator[0](x, alpha)
domain_predictions = self.discriminator[1:](x_grad_rev).squeeze()
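
Put together, a minimal sketch of the model's forward pass under these assumptions (alpha is my name for the per-call reversal strength, not something from the paper):

def forward(self, x, alpha):
    # nn.Sequential cannot pass a second argument along, so apply
    # the gradient-reversal layer by hand with the per-call alpha...
    x_grad_rev = self.discriminator[0](x, alpha)
    # ...then run the remaining layers as usual (slicing an
    # nn.Sequential returns another nn.Sequential)
    domain_predictions = self.discriminator[1:](x_grad_rev).squeeze()
    return domain_predictions

An alternative that keeps the plain self.discriminator(x) call working is to leave GradientReversal.forward(self, x) unchanged and instead set self.discriminator[0].lambda_ = alpha before each pass, since the module hands its stored lambda_ to the Function anyway.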

What is the specific error message that you get when trying to add lambda_?