# Why overhead? How to efficiently mimic `nn.Linear`?

Hi! I am struggling to understand why the code below introduces overhead (especially noticeable in backward) compared to `torch.nn.Linear`. What is suboptimal in my implementation, and how can I improve it?
(For certain reasons I really need to wrap `torch.nn.functional.linear` in `MyLinearFunction` and `torch.nn.Linear` in `MyLinear`.)

You can run the code in Colab:

If you have any problems with the link above, here is the same code:

``````import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor


def bench(module, n=1000):
    forward_total = 0
    backward_total = 0
    loader = DataLoader(mnist, batch_size=128, shuffle=True)
    it = iter(loader)
    for _ in range(n):
        try:
            x, y = next(it)  # the "dataloader step"
        except StopIteration:
            it = iter(loader)
            x, y = next(it)
        x = x.view(-1, 784)
        now = time.perf_counter()
        out = module(x)
        forward_total += time.perf_counter() - now
        loss = loss_fn(out, y)
        now = time.perf_counter()
        loss.backward()
        backward_total += time.perf_counter() - now
    return forward_total / n, backward_total / n


class MyLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias=None):
        # the "dirty" way: stash the inputs directly on ctx
        ctx.x = x
        ctx.weight = weight
        ctx.bias = bias
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        x, weight, bias = ctx.x, ctx.weight, ctx.bias
        # standard gradients for out = x @ weight.t() + bias
        grad_x = grad_output @ weight
        grad_weight = grad_output.t() @ x
        grad_bias = grad_output.sum(0) if bias is not None else None
        return grad_x, grad_weight, grad_bias


class MyLinear(nn.Linear):
    def forward(self, x):
        return MyLinearFunction.apply(x, self.weight, self.bias)


mnist_path = 'mnist_data'
mnist = MNIST(
    mnist_path,
    train=True,
    download=True,
    transform=Compose([
        ToTensor(),
        lambda x: x[0]  # MNIST images contain a single channel
    ]),
)
loss_fn = nn.CrossEntropyLoss()

m1 = nn.Linear(784, 10)
m2 = MyLinear(784, 10)
m1.weight.data = m2.weight.data.clone()
m1.bias.data = m2.bias.data.clone()

vf, vb = bench(m1)
cf, cb = bench(m2)
print(f'vanilla forward: {vf:.6f}, vanilla backward: {vb:.6f}')
print(f'custom  forward: {cf:.6f}, custom  backward: {cb:.6f}')
``````

Hi,

I think you should use `save_for_backward()` to save input and output as done in the main doc.
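For example, something along these lines (just a sketch of the change; the rest of your backward stays the same):

``````@staticmethod
def forward(ctx, x, weight, bias=None):
    # save the tensors needed in backward with save_for_backward
    # instead of stashing them directly on ctx
    ctx.save_for_backward(x, weight, bias)
    return F.linear(x, weight, bias)

@staticmethod
def backward(ctx, grad_output):
    # retrieve them via ctx.saved_tensors
    x, weight, bias = ctx.saved_tensors
    ...
``````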

Thank you, but it seems to have no effect on performance, right?
Moreover, I have to stick with the dirty way (saving tensors as `ctx` attributes), because I am going to attach some information to these tensors in backward, so I need the original tensors, not copies. Like this:

``````@staticmethod
def backward(ctx, grad_output):
    x, weight, bias = ctx.x, ctx.weight, ctx.bias
    weight.some_info = get_info()
    ...

m = MyLinear(...)
...
loss.backward()
print(m.weight.some_info)
``````

If there are any better ways to achieve the same, I would be glad to know.

The thing is that doing `ctx.smth = some_input` creates a reference cycle and will leak memory. So you do want to use `save_for_backward()` for anything that is an input or output. It won't do any copy and will be as lightweight as `ctx.xxx`, which you should keep only for other (non-tensor) objects.
Also, I just saw that you should not instantiate an instance of a Function but use the `.apply` static method. Refer to the doc link above on how to use a Function.

I do not instantiate any instances of Function; I use the static method `apply` in the module's `forward` method, exactly as you describe:

``````class MyLinear(nn.Linear):
    def forward(self, x):
        return MyLinearFunction.apply(x, self.weight, self.bias)
``````

Oh, I missed the nn.Module wrapper, sorry!
Have you tried using `save_for_backward()`? It should help reduce memory pressure and potentially help with speed.

Unfortunately, using `save_for_backward` does not change the figures at all, so the overhead is still there. Here is the updated code:

Hi,

I'm not sure I understand exactly what happens here, but it feels like you're measuring more noise than anything else:
Try creating `data = torch.rand(128, 784), torch.rand(128).long()` outside of your function and replacing the dataloader step with `x, y = data` (see the sketch below).
Just to reduce memory use on that side; with this change, the backward pass runs 5x faster on my machine.
Now, changing your custom backward to do nothing and `return None, None, None` actually gives a similar runtime to the original linear layer's (with a lot of noise across runs, though).
So I would guess your custom backward looks slower because the ops are so small that dispatching the 4 Python-level ops is actually the most expensive thing that happens here. What do you think?
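Concretely, creating the data outside of `bench` could look like this (a sketch that reuses `loss_fn` and the imports from your script; the batch size of 128 is arbitrary):

``````# created once, at module level, so the dataloader is out of the measurement
data = torch.rand(128, 784), torch.rand(128).long()

def bench(module, n=1000):
    forward_total = 0
    backward_total = 0
    for _ in range(n):
        x, y = data  # instead of the dataloader step
        now = time.perf_counter()
        out = module(x)
        forward_total += time.perf_counter() - now
        loss = loss_fn(out, y)
        now = time.perf_counter()
        loss.backward()
        backward_total += time.perf_counter() - now
    return forward_total / n, backward_total / n
``````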

I updated the code so I create `x, y` the way you recommended:

Now, there are two implementations of my custom layer’s backward:
A)

``````x, weight, bias = ctx.saved_tensors
``````

B)

``````return None, None, None
``````
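(For reference, A goes on to compute and return the standard linear-layer gradients, roughly like this:)

``````@staticmethod
def backward(ctx, grad_output):
    x, weight, bias = ctx.saved_tensors
    # standard gradients for out = x @ weight.t() + bias
    grad_x = grad_output @ weight
    grad_weight = grad_output.t() @ x
    grad_bias = grad_output.sum(0) if bias is not None else None
    return grad_x, grad_weight, grad_bias
``````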

On my machine:
A
vanilla forward: 0.0000576, vanilla backward: 0.0001834
custom forward: 0.0000698, custom backward: 0.0002252

B
vanilla forward: 0.0000573, vanilla backward: 0.0001784
custom forward: 0.0000685, custom backward: 0.0001551

So, on my machine doing nothing in backward leads to faster execution (the difference looks statistically significant; I am surprised that on your machine matrix operations perform as fast as doing nothing).

On Colab, option B is also much faster than A.

Did you run it multiple times?
The current results on Google Colab are:

Which would mean that they are ~the same with the actual code in the backward.

My point was more that running the script multiple times still gives too much variance to draw conclusions. This is what I get on my machine:

``````$ python workspace/dev-pytorch/test/tmp.py
vanilla forward: 0.0001763, vanilla backward: 0.0003210
custom  forward: 0.0001498, custom  backward: 0.0003045
$ python workspace/dev-pytorch/test/tmp.py
vanilla forward: 0.0001494, vanilla backward: 0.0002869
custom  forward: 0.0001435, custom  backward: 0.0003164
$ python workspace/dev-pytorch/test/tmp.py
vanilla forward: 0.0001242, vanilla backward: 0.0002429
custom  forward: 0.0001492, custom  backward: 0.0002765
$ python workspace/dev-pytorch/test/tmp.py
vanilla forward: 0.0001547, vanilla backward: 0.0002743
custom  forward: 0.0001138, custom  backward: 0.0002271
$ python workspace/dev-pytorch/test/tmp.py
vanilla forward: 0.0001221, vanilla backward: 0.0002303
custom  forward: 0.0001534, custom  backward: 0.0002434
$ python workspace/dev-pytorch/test/tmp.py
vanilla forward: 0.0001413, vanilla backward: 0.0002763
custom  forward: 0.0001706, custom  backward: 0.0003036
$ python workspace/dev-pytorch/test/tmp.py
vanilla forward: 0.0001231, vanilla backward: 0.0002354
custom  forward: 0.0001527, custom  backward: 0.0002881
$ python workspace/dev-pytorch/test/tmp.py
vanilla forward: 0.0001215, vanilla backward: 0.0001913
custom  forward: 0.0001504, custom  backward: 0.0002560
$ python workspace/dev-pytorch/test/tmp.py
vanilla forward: 0.0001071, vanilla backward: 0.0002042
custom  forward: 0.0001372, custom  backward: 0.0002434

``````

Have you tried commenting out `return None, None, None` in `backward`? If I comment it out, `nn.Linear` is always faster than `MyLinear`, both on my machine and on Colab.

I tried running the script multiple times with n=10000; the figures change from run to run, but the outcome of the comparison is always the same.