When to implement backward()?

I am reading through the ‘Extending PyTorch’ tutorial, which implements a new nn.Module by using a custom Function that has both forward() and backward() implemented. My question is: if I want to define a new module that just adds a random weight to the input tensor, and I define it like this:

import torch
import torch.nn as nn

class CustomLayer(nn.Module):
    def __init__(self, input_dim):
        super(CustomLayer, self).__init__()
        self.input_dim = input_dim
        # Registering the tensor as an nn.Parameter makes it learnable
        self.weight = nn.Parameter(torch.Tensor(input_dim))
        self.weight.data.uniform_(-0.05, 0.05)

    def forward(self, input):
        return torch.add(input, self.weight)

Then do I need to implement a backward() as well? After all, in the forward method I am not using a custom Function, just an operation that is already implemented in PyTorch.
Does having an nn.Parameter() in my constructor require me to write a backward() function? Do the parameters get updated without having a backward()?
To my understanding, you only need to write a backward() function if you implement a custom function and want to somehow manipulate the backward operations. So, for example, the module in the tutorial can simply be written as:

class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(Linear, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(output_features))
        else:
            self.register_parameter('bias', None)

        self.weight.data.uniform_(-0.1, 0.1)
        if self.bias is not None:
            self.bias.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        # weight has shape (output_features, input_features), so transpose it
        output = torch.matmul(input, self.weight.t())
        if self.bias is not None:
            output = output + self.bias
        return output

without a backward() method!
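
For what it's worth, a quick sanity check (a minimal sketch, assuming the Linear class above) shows that autograd produces gradients for both parameters even though the module defines no backward():

import torch

layer = Linear(4, 3)
x = torch.randn(2, 4)
out = layer(x)          # forward pass built from autograd-tracked ops
out.sum().backward()    # autograd derives the backward pass automatically

print(layer.weight.grad.shape)  # torch.Size([3, 4])
print(layer.bias.grad.shape)    # torch.Size([3])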

Hi,

Yes, you only need to write a backward function if:

  • You want to compute something that is not the “true” gradient that autograd would compute (see the sketch after this list)
  • Your forward function is not handled by autograd (because you use a custom operation or a third-party library)
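
For the first case, here is a minimal sketch (the RoundSTE name and the straight-through trick are just illustrative, not from the tutorial) of a Function whose backward() deliberately returns something other than the true gradient:

import torch

class RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Round in the forward pass
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Not the true gradient: round() has zero gradient almost
        # everywhere, so we pretend it was the identity instead
        return grad_output

x = torch.tensor([0.3, 1.7], requires_grad=True)
RoundSTE.apply(x).sum().backward()
print(x.grad)  # tensor([1., 1.]) rather than zeros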

And the example in the tutorial is indeed artificial, but it is just an example of how to do that for a simple function.
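
And regarding the nn.Parameter() question: registering a parameter never requires a backward(). Here is a minimal sketch, assuming the CustomLayer from the question, showing that its weight receives a gradient and gets updated by a standard optimizer:

import torch

layer = CustomLayer(3)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

layer(torch.randn(3)).sum().backward()  # autograd fills layer.weight.grad
optimizer.step()                        # the optimizer updates the parameter

print(layer.weight.grad)  # non-None: the gradient was computed automatically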


That makes sense. Thank you for the clarification!