Hi! I am struggling to understand why the code below introduces overhead (especially noticeable in backward) compared to torch.nn.Linear. What is suboptimal in my implementation, and how can I improve it?
(for reasons specific to my project, I really do need to wrap torch.nn.functional.linear in MyLinearFunction and torch.nn.Linear in MyLinear)
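Roughly, the setup looks like this (a simplified sketch of my code; the gradient formulas assume a 2D input batch, and the exact bodies in my real script may differ slightly):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias):
        # keep references to the inputs so backward can use them
        ctx.x = x
        ctx.weight = weight
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        # standard gradients of y = x @ W.t() + b for a 2D input
        grad_x = grad_output @ ctx.weight
        grad_weight = grad_output.t() @ ctx.x
        grad_bias = grad_output.sum(0)
        return grad_x, grad_weight, grad_bias

class MyLinear(nn.Linear):
    def forward(self, x):
        return MyLinearFunction.apply(x, self.weight, self.bias)
```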
Thank you, but it seems to have no effect on performance, right?
Moreover, I have to stick with the dirty way (saving tensors as ctx attributes), because I am going to attach some information to these tensors in backward, so I need the original tensors, not copies. Like this:
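Sketched, with my_backward_info as a placeholder name for the information I actually attach:

```python
@staticmethod
def backward(ctx, grad_output):
    # attach extra information to the *original* input tensor saved in forward
    # (my_backward_info is just an illustrative attribute name)
    ctx.x.my_backward_info = grad_output.detach().abs().mean()
    grad_x = grad_output @ ctx.weight
    grad_weight = grad_output.t() @ ctx.x
    grad_bias = grad_output.sum(0)
    return grad_x, grad_weight, grad_bias
```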
The thing is that doing ctx.smth = some_input creates a reference cycle and will leak memory. So you do want to use save_for_backward() for anything that is an input or output. It won't make any copy and will be as lightweight as ctx.xxx for other objects.
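For example, for the two tensor inputs in the sketch above:

```python
# in forward -- stores direct references on ctx (creates the cycle described above):
ctx.x, ctx.weight = x, weight

# in forward -- the recommended way for inputs/outputs, no copy is made:
ctx.save_for_backward(x, weight)

# in backward -- retrieve them with:
x, weight = ctx.saved_tensors
```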
Also, I just saw that you should not instantiate a Function directly but use the .apply static method instead. Refer to the doc link above on how to use a Function.
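I.e., with the sketch above:

```python
x = torch.rand(128, 784)
layer = MyLinear(784, 10)

# legacy style (don't do this): instantiating the Function and calling it
# out = MyLinearFunction()(x, layer.weight, layer.bias)

# correct: call the static .apply method (MyLinear.forward already does this)
out = MyLinearFunction.apply(x, layer.weight, layer.bias)
```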
Oh, I missed the nn.Module wrapper, sorry!
Have you tried using save_for_backward()? It should reduce memory pressure and potentially help with speed.
I'm not sure I understand exactly what happens here, but it feels like you're measuring more noise than anything:
Try creating data = torch.rand(128, 784), torch.rand(128).long() outside of your function and replacing the dataloader step with x, y = data. That just reduces memory use on that side, and with this change the backward pass runs 5x faster on my machine.
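Something like this (a sketch; model, loss_fn and opt stand in for whatever your script builds):

```python
import torch

# build one batch up front instead of iterating a DataLoader
data = torch.rand(128, 784), torch.rand(128).long()

def train_step(model, loss_fn, opt):
    x, y = data                  # replaces the dataloader step
    loss = loss_fn(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
```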
Now, changing your custom backward to do nothing and return None, None, None actually gives a runtime similar to the original linear layer's (with a lot of noise across runs, though).
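I.e., replacing the backward of your Function with:

```python
@staticmethod
def backward(ctx, grad_output):
    # do no work at all, just to measure the fixed overhead of the custom Function
    return None, None, None
```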
So I would guess your custom backward looks slower because the ops are so small that the 4 Python-level ops in it are actually the most expensive thing happening here. What do you think?
So, on my machine, doing nothing in backward leads to faster execution (the difference looks statistically significant; I am surprised that on your machine the matrix operations run as fast as doing nothing).
Have you tried commenting out the return None, None, None in backward? If I comment it out, nn.Linear is always faster than MyLinear, both on my machine and on Colab.
I tried running the script multiple times with n=10000; the figures change from run to run, but the result of the comparison is always the same.