Why overhead? How to efficiently mimic `nn.Linear`?

Hi! I am struggling to understand why the code below introduces any overhead (especially noticeable in backward) compared to torch.nn.Linear. What is suboptimal in my implementation, and how can I improve it?
(For certain reasons, I really need to wrap torch.nn.functional.linear in MyLinearFunction and torch.nn.Linear in MyLinear.)

Here is the code:

import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor

def bench(module, n=1000):
    forward_total = 0
    backward_total = 0
    for _ in range(n):
        x, y = next(iter(dataloader))
        x = x.view(-1, 784)
        now = time.perf_counter()
        out = module(x)
        forward_total += time.perf_counter() - now
        loss = loss_fn(out, y)
        now = time.perf_counter()
        loss.backward()
        backward_total += time.perf_counter() - now
    return forward_total / n, backward_total / n

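class MyLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias=None):
        # Keep references to the inputs for use in backward.
        ctx.x = x
        ctx.weight = weight
        ctx.bias = bias
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight, bias = ctx.x, ctx.weight, ctx.bias
        # Gradients w.r.t. x, weight and bias, in the order of forward's inputs.
        return grad_out @ weight, grad_out.t() @ x, grad_out.sum(dim=0)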

class MyLinear(nn.Linear):
    def forward(self, x):
        return MyLinearFunction.apply(x, self.weight, self.bias)

mnist_path = 'mnist_data'
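mnist = MNIST(
    mnist_path,
    download=True,  # assumed; the original arguments were lost
    transform=Compose([
        ToTensor(),
        lambda x: x[0],  # mnist images contain one channel
    ]),
)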
dataloader = DataLoader(mnist, batch_size=128)
loss_fn = nn.CrossEntropyLoss()

m1 = nn.Linear(784, 10)
m2 = MyLinear(784, 10)
m1.weight.data = m2.weight.data.clone()
m1.bias.data = m2.bias.data.clone()

vf, vb = bench(m1)
cf, cb = bench(m2)
print(f'vanilla forward: {vf:.6f}, vanilla backward: {vb:.6f}')
print(f'custom  forward: {cf:.6f}, custom  backward: {cb:.6f}')


I think you should use save_for_backward() to save input and output as done in the main doc.

Thank you, but it seems to have no effect in terms of performance, right?
Moreover, I have to prefer the dirty way (saving tensors as ctx attributes), because I am going to attach some information to these tensors in backward, so I need the original tensors, not copies. Like this:

def backward(ctx, grad_out):
    x, weight, bias = ctx.x, ctx.weight, ctx.bias
    weight.some_info = get_info()
    return grad_out @ weight, grad_out.t() @ x, grad_out.sum(dim=0)

m = MyLinear(...)

If there are any better ways to achieve the same, I would be glad to know.

The thing is that doing ctx.smth = some_input creates a reference cycle and will leak memory. So you do want to use save_for_backward() for anything that is an input or output. It won’t do any copy and will be as light as ctx.xxx for other objects.
Also, I just saw that you should not instantiate a Function but use the .apply static method instead. Refer to the doc link above on how to use a Function.
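
To make the suggestion concrete, here is a minimal sketch of what the save_for_backward() variant could look like. The names SavedLinearFunction and max_grad_difference are hypothetical and not from the original post; the sketch assumes 2-D input, as in the benchmark above, and only checks that the gradients agree with those of plain F.linear. It says nothing about speed.

```python
import torch
import torch.nn.functional as F


class SavedLinearFunction(torch.autograd.Function):
    """Hypothetical variant of MyLinearFunction using save_for_backward."""

    @staticmethod
    def forward(ctx, x, weight, bias=None):
        # Tensors that are inputs of forward go through save_for_backward
        # instead of being stored as plain ctx attributes.
        ctx.save_for_backward(x, weight)
        # A plain Python bool is fine as a ctx attribute.
        ctx.has_bias = bias is not None
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x = grad_out @ weight        # (N, out) @ (out, in) -> (N, in)
        grad_weight = grad_out.t() @ x    # (out, N) @ (N, in) -> (out, in)
        grad_bias = grad_out.sum(dim=0) if ctx.has_bias else None
        return grad_x, grad_weight, grad_bias


def max_grad_difference():
    """Largest absolute gap between custom and built-in gradients."""
    torch.manual_seed(0)
    x = torch.randn(8, 5, dtype=torch.double, requires_grad=True)
    w = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
    b = torch.randn(3, dtype=torch.double, requires_grad=True)
    custom_out = SavedLinearFunction.apply(x, w, b)
    custom = torch.autograd.grad(custom_out.pow(2).sum(), (x, w, b))
    reference = torch.autograd.grad(F.linear(x, w, b).pow(2).sum(), (x, w, b))
    return max((c - r).abs().max().item() for c, r in zip(custom, reference))


print(max_grad_difference())
```

If the printed difference is at the level of floating-point rounding, the hand-written backward matches autograd for this shape.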

I do not instantiate any instances of Function and I use the static method apply exactly how you describe in the forward method:

class MyLinear(nn.Linear):
    def forward(self, x):
        return MyLinearFunction.apply(x, self.weight, self.bias)

Oh, I missed the nn module wrapper, sorry!
Have you tried using save_for_backward()? It should help reduce memory stress and potentially help with speed.

Unfortunately, using save_for_backward does not change the figures at all, so the overhead is still there. Here is the updated code:


I’m not sure I understand exactly what happens here, but it feels like you’re measuring more noise than anything.
Try adding data = torch.rand(128, 784), torch.rand(128).long() outside of your function and replacing the dataloader step with x, y = data.
This is just to reduce memory use on that side; with that change, the backward pass runs 5x faster on my machine.
Now, changing your custom backward to do nothing and return None, None, None actually gives a runtime similar to the original linear layer’s (with a lot of noise between runs, though).
So I would guess your custom backward looks slower because the ops are so small that calling 4 Python ops is actually the most expensive thing that happens here. What do you think?
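
As a rough illustration of the suggestion above, the benchmark loop could be rewritten to reuse one fixed batch. This is only a sketch: the name bench_fixed and the warmup parameter are my own additions, and the labels come from torch.randint rather than torch.rand(128).long() so that they are not all zero. Note that the backward timing also includes the backward of the loss itself.

```python
import time

import torch
import torch.nn as nn


def bench_fixed(module, data, loss_fn, n=100, warmup=10):
    """Average forward/backward time on one fixed batch (no dataloader)."""
    x, y = data
    forward_total = 0.0
    backward_total = 0.0
    for i in range(warmup + n):
        module.zero_grad()
        t0 = time.perf_counter()
        out = module(x)
        t1 = time.perf_counter()
        loss = loss_fn(out, y)
        t2 = time.perf_counter()
        loss.backward()  # includes the backward of the loss function
        t3 = time.perf_counter()
        if i >= warmup:  # discard the first iterations
            forward_total += t1 - t0
            backward_total += t3 - t2
    return forward_total / n, backward_total / n


# Fixed synthetic batch with the same shape as a flattened MNIST batch.
data = torch.rand(128, 784), torch.randint(0, 10, (128,))
f, b = bench_fixed(nn.Linear(784, 10), data, nn.CrossEntropyLoss())
print(f'forward: {f:.7f}, backward: {b:.7f}')
```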

I updated the code so I create x, y the way you recommended:

Now, there are two implementations of my custom layer’s backward:

A:
x, weight, bias = ctx.saved_tensors
return grad_out @ weight, grad_out.t() @ x, grad_out.sum(dim=0)

B:
return None, None, None

On my machine:

A:
vanilla forward: 0.0000576, vanilla backward: 0.0001834
custom  forward: 0.0000698, custom  backward: 0.0002252

B:
vanilla forward: 0.0000573, vanilla backward: 0.0001784
custom  forward: 0.0000685, custom  backward: 0.0001551

So, on my machine doing nothing in backward leads to faster execution (the difference looks statistically significant; I am surprised that on your machine matrix operations perform as fast as doing nothing).

On Colab, option B is also much faster than A.

Did you run it multiple times?
The current results on google colab are:
Which would mean that they are ~the same with the actual code in the backward.

My point was more that running the script multiple times still gives too much variance to draw a conclusion. This is what I get on my machine:

$ python workspace/dev-pytorch/test/tmp.py 
vanilla forward: 0.0001763, vanilla backward: 0.0003210
custom  forward: 0.0001498, custom  backward: 0.0003045
$ python workspace/dev-pytorch/test/tmp.py 
vanilla forward: 0.0001494, vanilla backward: 0.0002869
custom  forward: 0.0001435, custom  backward: 0.0003164
$ python workspace/dev-pytorch/test/tmp.py 
vanilla forward: 0.0001242, vanilla backward: 0.0002429
custom  forward: 0.0001492, custom  backward: 0.0002765
$ python workspace/dev-pytorch/test/tmp.py 
vanilla forward: 0.0001547, vanilla backward: 0.0002743
custom  forward: 0.0001138, custom  backward: 0.0002271
$ python workspace/dev-pytorch/test/tmp.py 
vanilla forward: 0.0001221, vanilla backward: 0.0002303
custom  forward: 0.0001534, custom  backward: 0.0002434
$ python workspace/dev-pytorch/test/tmp.py 
vanilla forward: 0.0001413, vanilla backward: 0.0002763
custom  forward: 0.0001706, custom  backward: 0.0003036
$ python workspace/dev-pytorch/test/tmp.py 
vanilla forward: 0.0001231, vanilla backward: 0.0002354
custom  forward: 0.0001527, custom  backward: 0.0002881
$ python workspace/dev-pytorch/test/tmp.py 
vanilla forward: 0.0001215, vanilla backward: 0.0001913
custom  forward: 0.0001504, custom  backward: 0.0002560
$ python workspace/dev-pytorch/test/tmp.py 
vanilla forward: 0.0001071, vanilla backward: 0.0002042
custom  forward: 0.0001372, custom  backward: 0.0002434
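
One way to put a number on that variance, instead of eyeballing separate script runs, is to repeat the measurement inside one process and report the spread. The sketch below is an assumption of mine, not something from this thread: the helper names backward_time and compare are hypothetical, and two nn.Linear instances stand in for the vanilla and custom layers.

```python
import statistics
import time

import torch
import torch.nn as nn


def backward_time(module, x, y, loss_fn, n):
    """Mean wall-clock time of loss.backward() over n iterations."""
    total = 0.0
    for _ in range(n):
        module.zero_grad()
        loss = loss_fn(module(x), y)
        start = time.perf_counter()
        loss.backward()
        total += time.perf_counter() - start
    return total / n


def compare(modules, x, y, loss_fn, repeats=10, n=200):
    """Return {name: (mean, stdev)} of the per-repeat backward times."""
    samples = {name: [] for name in modules}
    for _ in range(repeats):
        # Alternate the modules so slow drift affects both alike.
        for name, module in modules.items():
            samples[name].append(backward_time(module, x, y, loss_fn, n))
    return {
        name: (statistics.mean(s), statistics.stdev(s))
        for name, s in samples.items()
    }


x, y = torch.rand(128, 784), torch.randint(0, 10, (128,))
# Stand-ins: replace the second entry with the custom layer to compare.
modules = {'vanilla': nn.Linear(784, 10), 'other': nn.Linear(784, 10)}
for name, (mean, std) in compare(modules, x, y, nn.CrossEntropyLoss()).items():
    print(f'{name}: {mean:.7f} +/- {std:.7f}')
```

If the gap between the two means is small compared to the standard deviations, the measurement probably cannot separate the two layers.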

Have you tried commenting out return None, None, None in backward? If I comment it out, nn.Linear is always faster than MyLinear on my machine and on Colab as well.

I tried to run the script multiple times with n=10000, the figures change from time to time, but the result of comparison is always the same.