Separating out forward & backward passes as tensor operations

Hi,

I am trying to separate the forward and backward passes of an nn.Module into explicit tensor operations (I am working on a peculiar requirement here :) ).

The interfaces look like this.

class LinearForward:
    def __init__(self, in_features: int, out_features: int, bias=True, device=torch.device("cpu")):
        ...
    def __call__(self, input: Tensor) -> Tensor:
        self.input = input
        output = input.matmul(self.weight.t())
        if self.bias is not None:
            output += self.bias
        return output

class LinearBackward:
    def __init__(self, forward: LinearForward, device=torch.device("cpu")):
        ...
    def __call__(self, grad_output: Tensor) -> Tuple[Tensor, Tensor, Any]:
        grad_input = grad_output.matmul(self.weight)
        grad_weight = grad_output.t().matmul(self.forward.get_input())
        grad_bias = None
        if self.bias is not None:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias
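
One way to sanity-check the backward formulas above is to compare them against what autograd produces for nn.Linear. This is only a minimal sketch on CPU: the sizes are arbitrary illustration values, and it applies the formulas directly rather than going through the LinearForward/LinearBackward classes, whose constructors are not shown here.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch, in_features, out_features = 4, 3, 2

x = torch.randn(batch, in_features, requires_grad=True)
linear = torch.nn.Linear(in_features, out_features)

# Reference: let autograd compute the gradients.
out = linear(x)
grad_output = torch.randn_like(out)
out.backward(grad_output)

# Manual backward, using the same formulas as LinearBackward.__call__.
with torch.no_grad():
    grad_input = grad_output.matmul(linear.weight)   # (batch, in_features)
    grad_weight = grad_output.t().matmul(x)          # (out_features, in_features)
    grad_bias = grad_output.sum(0)                   # (out_features,)

input_ok = torch.allclose(grad_input, x.grad, atol=1e-5)
weight_ok = torch.allclose(grad_weight, linear.weight.grad, atol=1e-5)
bias_ok = torch.allclose(grad_bias, linear.bias.grad, atol=1e-5)
print(input_ok, weight_ok, bias_ok)
```

If all three checks pass, the manual formulas agree with autograd for this case, so any remaining difference is about performance rather than correctness.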

I tested this implementation against nn.Linear and observed the following on CUDA devices.

  1. My tensor implementation is about 10-15% slower than nn.Linear.
  2. In both cases, almost all the time is spent in the backward pass, while the forward pass takes <1% of the training time (at least for the Linear layer).

What would be the reason for this? Is it because my forward and backward classes are written in Python and each __call__ method is getting synchronized with the Python GIL (like the forward and backward hook implementations in nn.Modules)?

I look forward to hearing from the community.

PS:
I have followed these threads previously.

Hi,

Do you use torch.cuda.synchronize() properly when doing your timing? It sounds odd that the forward pass takes <1%.
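
For reference, here is a rough sketch of what synchronized timing could look like. CUDA kernels are launched asynchronously, so without a synchronize before reading the clock, the measured forward time may only reflect the launch cost. The helper name time_op, the layer sizes, and the iteration counts are assumptions made up for this example; it falls back to CPU when no GPU is available.

```python
import time
import torch

def time_op(fn, iters=20, warmup=3):
    """Average wall-clock seconds per call of fn, synchronizing on CUDA."""
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):
        fn()
    if use_cuda:
        torch.cuda.synchronize()  # finish pending kernels before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if use_cuda:
        torch.cuda.synchronize()  # wait for the timed kernels to actually finish
    return (time.perf_counter() - start) / iters

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer = torch.nn.Linear(256, 256).to(device)
x = torch.randn(64, 256, device=device, requires_grad=True)

def forward_only():
    return layer(x)

def forward_backward():
    layer.zero_grad()
    x.grad = None
    layer(x).sum().backward()

fwd_time = time_op(forward_only)
total_time = time_op(forward_backward)
print(f"forward: {fwd_time:.6f}s, forward+backward: {total_time:.6f}s")
```

torch.cuda.Event with enable_timing=True is another option for GPU-side timing; the wall-clock version is used here only because it also runs on CPU.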

Hey @albanD,
Oh, you are correct. That was my mistake.
Could you help clarify (1)?

The main reason I can think of is that when you use autograd, we only compute gradients for the things that require gradients. In particular, if the input does not require gradients, then autograd won’t compute grad_input, which might explain some of the performance difference.
Also, the built-in function runs purely in C++, which might explain a few percent of the difference, especially if you test with small input Tensors.
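
The requires_grad point can be seen in a small sketch like the one below (sizes are arbitrary). When the input does not require gradients, autograd leaves its .grad as None and only fills in the parameter gradients, whereas a manual backward like the one in the question always computes grad_input.

```python
import torch

layer = torch.nn.Linear(3, 2)

# Input that does not require grad: autograd has no reason to compute grad_input.
x_plain = torch.randn(4, 3)
layer(x_plain).sum().backward()
plain_has_grad = x_plain.grad is not None

# Input that requires grad: autograd also computes the gradient w.r.t. the input.
x_req = torch.randn(4, 3, requires_grad=True)
layer(x_req).sum().backward()
req_has_grad = x_req.grad is not None

print(plain_has_grad, req_has_grad)
```

For a fairer comparison, the manual backward could skip the grad_input matmul whenever the input does not need a gradient, for example for the first layer of a network.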