Why use LinearFunction in `Extending torch.nn`?

Hello everyone,

I’m a bit confused with how to properly extend torch.nn.

In basically all code samples I’ve seen, custom modules are compositions of already existing modules. But after playing around with tensors and autograd, it is tempting (if one wants flexibility) to just define parameters in a custom module’s __init__ and then perform some calculation in the forward function, similarly to how it is described here:


def forward(self, input):
    # See the autograd section for explanation of what happens here.
    return LinearFunction.apply(input, self.weight, self.bias)
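For reference, the LinearFunction in that snippet is the custom autograd Function from the "Extending torch.autograd" section of the docs; a sketch looks roughly like this (shapes assume input is (N, in_features) and weight is (out_features, in_features)):

```python
import torch

class LinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # Save tensors needed to compute gradients in backward.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # Only compute gradients that are actually needed.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias
```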

Now – is it o.k. to drop LinearFunction and perform the calculation directly like this:

def forward(self, input):
    output = input.mm(self.weight.t())
    if self.bias is not None:
        output += self.bias.unsqueeze(0).expand_as(output)
    return output

It seems to work, but I’ve never seen this practice in code examples, so my question basically is whether there’s something wrong with this approach.

Thanks for clarification.


The difference is that in the first case you have a single Function in the computational graph: LinearFunction, so the backward pass just calls the backward of that one Function.
In the second case, the computational graph contains several Functions: mm, transpose, addition, unsqueeze, expand… This means the backward pass has to traverse all of these functions and call the backward of each one. Moreover, to be able to perform this backward pass, all the intermediary results between them are saved in memory.
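You can inspect this longer graph directly via grad_fn (the node class names, e.g. MmBackward0, vary slightly between PyTorch versions, so treat the printed names as illustrative):

```python
import torch

x = torch.randn(2, 3, requires_grad=True)
w = torch.randn(4, 3, requires_grad=True)
b = torch.randn(4, requires_grad=True)

# Same computation as the "direct" forward above.
out = x.mm(w.t()) + b.unsqueeze(0).expand(2, 4)

# Walk one path of the autograd graph from the root: each tensor op
# left its own backward node behind.
node = out.grad_fn
while node is not None:
    print(type(node).__name__)
    node = node.next_functions[0][0] if node.next_functions else None
```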

The first one is thus more efficient in both speed and memory, but it requires you to implement the backward method by hand. The second one pays that cost in exchange for letting autograd compute the gradients, without having to implement the backward method explicitly.
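As a sanity check that nothing is wrong with the direct approach, the gradients autograd produces match what a hand-written linear backward would compute (shapes here are just illustrative):

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
w = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(2, dtype=torch.double, requires_grad=True)

# Direct computation: autograd differentiates the plain tensor ops.
out = x.mm(w.t())
out = out + b.unsqueeze(0).expand_as(out)
out.sum().backward()

# Hand-derived gradients for sum(x @ w.T + b), with grad_output = ones:
grad_out = torch.ones(4, 2, dtype=torch.double)
assert torch.allclose(x.grad, grad_out.mm(w))
assert torch.allclose(w.grad, grad_out.t().mm(x))
assert torch.allclose(b.grad, grad_out.sum(0))
```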


Thanks! That makes sense.