Defining backward() function in nn.Module?

Is it possible to define the backward() function in nn.Module?

Or do I have to use torch.autograd.Function? (as in https://github.com/jcjohnson/pytorch-examples#pytorch-defining-new-autograd-functions)

This module will have its own trainable parameters, and the backward gradient depends on those parameters. Does torch.autograd.Function support trainable parameters?
i.e. can I have torch.nn.Parameter(…) in torch.autograd.Function?


An nn.Module can be seen as a container of parameters that, in its forward method, applies a sequence of operations to an input, where those operations are differentiable with respect to the parameters.

I am not sure I understand what you want to do, but if you define an autograd.Function like this:

from torch.autograd import Function

class my_function(Function):
    @staticmethod
    def forward(ctx, input, parameters):
        ctx.save_for_backward(input, parameters)
        # output = [do something with input and parameters]
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, parameters = ctx.saved_tensors
        # grad_input = [derivative of the output wrt input] * grad_output
        # grad_parameters = [derivative of the output wrt parameters] * grad_output
        return grad_input, grad_parameters
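
To make that concrete, here is a minimal runnable sketch, assuming the operation is a simple elementwise scaling (output = input * weight, with a per-feature weight). The names ScaleFunction and weight are made up for illustration:

import torch
from torch.autograd import Function

class ScaleFunction(Function):
    # hypothetical example: output = input * weight, with input of shape
    # (batch, features) and weight of shape (features,)
    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return input * weight

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output * weight           # d(output)/d(input) = weight
        grad_weight = (grad_output * input).sum(0)  # d(output)/d(weight) = input, summed over the batch
        return grad_input, grad_weight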

And then you define a module whose forward method calls your Function:

class my_module(nn.Module):
    def __init__(self, ...):
        super(my_module, self).__init__()
        # register the trainable parameters; avoid the name self.parameters,
        # which would shadow nn.Module.parameters()
        self.my_parameters = nn.Parameter(...)  # init some parameters

    def forward(self, input):
        output = my_function.apply(input, self.my_parameters)  # here you call the Function!
        return output

Then when you call backward on the output (or on a loss computed from it), autograd will use your Function's backward to compute the gradients with respect to both the input and your parameters.
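
As a hedged end-to-end sketch building on the hypothetical ScaleFunction above (module and tensor names are again made up), the parameter receives a gradient just like any other nn.Parameter:

import torch
import torch.nn as nn

class ScaleModule(nn.Module):
    def __init__(self, num_features):
        super(ScaleModule, self).__init__()
        # trainable parameter that the custom Function will use
        self.weight = nn.Parameter(torch.randn(num_features))

    def forward(self, input):
        # the custom Function is called in forward; autograd invokes
        # its backward automatically during loss.backward()
        return ScaleFunction.apply(input, self.weight)

module = ScaleModule(4)
x = torch.randn(3, 4, requires_grad=True)
loss = module(x).sum()
loss.backward()
print(module.weight.grad)  # gradient wrt the trainable parameter
print(x.grad)              # gradient wrt the input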


Thanks a lot for the detailed reply!!
What if both the forward and backward functions in autograd.Function need to modify the parameters?