Hi,

I am trying to implement a standard gradient reversal layer, which looks something like this:
```python
class GradientReversalModule(nn.Module):
    def __init__(self, lambd):
        super(GradientReversalModule, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return x

    def backward(self, grad_value):
        return -grad_value * self.lambd
```
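For comparison, I have also seen the same layer written by subclassing `torch.autograd.Function` with static `forward`/`backward` methods and invoking it through `.apply()` (a sketch of that pattern, with the class name being my own choice):

```python
import torch
from torch.autograd import Function

class GradientReversalFunction(Function):
    @staticmethod
    def forward(ctx, x, lambd):
        # Stash lambd for use in backward; output is the input unchanged
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the incoming gradient; None for lambd,
        # since it is a plain float and needs no gradient
        return -ctx.lambd * grad_output, None

# usage: y = GradientReversalFunction.apply(x, 0.5)
```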
I am confused about whether I am supposed to inherit from `nn.Module` or from `torch.autograd.Function`. Also, what is the difference between the two? Please let me know, thanks!