My helper classes look like this:

import torch
from torch import nn
from torch.autograd import Function

class GradientReversalFunction(Function):
    """
    Gradient Reversal Layer from:
    Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)

    Forward pass is the identity function. In the backward pass, the
    upstream gradients are multiplied by -lambda (i.e. the gradient is reversed).
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        # Stash lambda_ on the context so backward() can read it.
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        lambda_ = ctx.lambda_
        lambda_ = grads.new_tensor(lambda_)
        # Negate (and scale) the upstream gradient; no gradient w.r.t. lambda_.
        dx = -lambda_ * grads
        return dx, None

class GradientReversal(nn.Module):
    def __init__(self, lambda_=1):
        super(GradientReversal, self).__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        # lambda_ must be forwarded to apply(); calling apply(x) alone
        # raises a TypeError, since forward() expects two arguments.
        return GradientReversalFunction.apply(x, self.lambda_)
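To check my understanding, here is a minimal sanity check (my own toy example; the tensor shape is arbitrary) showing the identity forward and the reversed gradient:

x = torch.ones(3, requires_grad=True)
y = GradientReversal(lambda_=1)(x)
y.sum().backward()
print(x.grad)  # tensor([-1., -1., -1.]): forward is the identity, gradient is negated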
and in my model I use
self.discriminator = nn.Sequential(
    GradientReversal(),
    nn.Linear(40, 20),
    nn.ReLU(),
    nn.Linear(20, 10),
    nn.ReLU(),
    nn.Linear(10, 1),
)
and in the forward pass I call
self.discriminator(x).squeeze()
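For context, the input features have size 40 (matching the first nn.Linear), so a batch passes through like this (batch size 8 is just an example):

x = torch.randn(8, 40)
logits = self.discriminator(x).squeeze()  # shape: (8,)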
Now, instead of a fixed lambda_=1, I would like to call my forward function with varying values of lambda_. I have tried updating the forward call of GradientReversal and the apply() call accordingly, but I only get errors.
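Concretely, my attempt looked roughly like this (a sketch of the change, not working code):

class GradientReversal(nn.Module):
    def forward(self, x, lambda_):
        # take lambda_ at call time instead of at construction time
        return GradientReversalFunction.apply(x, lambda_)

# and then in the model's forward pass:
# self.discriminator(x, lambda_)  # fails with a TypeError

As far as I can tell, nn.Sequential only threads a single input through its child modules, so the extra lambda_ argument never reaches GradientReversal.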