Can I specify backward() function in my custom layer by inheriting nn.Module?

I am trying to implement a standard gradient reversal layer which looks something like this:

class GradientReversalModule(nn.Module):
    def __init__(self, lambd):
        super().__init__()  # required before assigning attributes on an nn.Module
        self.lambd = lambd
    def forward(self, x):
        return x
    def backward(self, grad_value):
        return -grad_value * self.lambd

I am confused about whether I am supposed to inherit from nn.Module or torch.autograd.Function. Also, what is the difference between the two? Please do let me know, thanks!


An nn.Module is just a convenient construct to handle parameters, buffers, and interaction with optimizers in the context of torch.nn.

An autograd.Function defines a new elementary op in the autograd engine.

So if you want to specify the backward for a given op, you want a custom autograd.Function. See the doc here on how to do that.
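To make that concrete, here is a minimal sketch of a gradient reversal layer built on a custom autograd.Function, following the static-method `forward`/`backward` pattern from the extending-autograd docs (the class names are just illustrative):

```python
import torch
from torch import nn

class GradientReversalFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd       # stash the scale factor for the backward pass
        return x.view_as(x)     # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse (and scale) the incoming gradient.
        # lambd is not a tensor input, so it gets no gradient (None).
        return -ctx.lambd * grad_output, None

class GradientReversalModule(nn.Module):
    """Thin nn.Module wrapper so the op composes with nn.Sequential etc."""
    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambd)
```

Used like any other layer, the forward pass is the identity, but gradients flowing back through it are multiplied by `-lambd`:

```python
x = torch.ones(3, requires_grad=True)
y = GradientReversalModule(lambd=2.0)(x)
y.sum().backward()
print(x.grad)  # tensor([-2., -2., -2.])
```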

Thanks for letting me know @albanD!!