Why overhead? How to efficiently mimic `nn.Linear`?

The thing is that doing `ctx.smth = some_input` creates a reference cycle and will leak memory. So you do want to use `save_for_backward()` for anything that is an input or output. It won't make any copy and will be as light as `ctx.xxx` is for other objects.
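Here is a minimal sketch of that split. The `ScaledMul` Function and its `scale` argument are hypothetical and exist only for illustration. The tensors `x` and `w` go through `ctx.save_for_backward()`, while the plain Python float `scale` is stored directly as a `ctx` attribute.

```python
import torch


class ScaledMul(torch.autograd.Function):
    """Hypothetical example Function: out = x * w * scale (elementwise)."""

    @staticmethod
    def forward(ctx, x, w, scale):
        # Input/output tensors: save them through save_for_backward
        # instead of assigning them to ctx attributes.
        ctx.save_for_backward(x, w)
        # Non-tensor values (a Python float here) can live on ctx directly.
        ctx.scale = scale
        return x * w * scale

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        # One gradient per forward input; None for the non-tensor `scale`.
        return grad_out * w * ctx.scale, grad_out * x * ctx.scale, None


x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
out = ScaledMul.apply(x, w, 2.0)
out.sum().backward()
```

After `backward()`, `x.grad` equals `2.0 * w` and `w.grad` equals `2.0 * x`, since the upstream gradient of `sum()` is all ones.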
Also, I just noticed that you should not instantiate a Function yourself; call its static `.apply` method instead. Refer to the doc link above on how to use a Function.
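To tie this back to mimicking `nn.Linear`, here is a sketch that follows the `LinearFunction` pattern from the PyTorch "Extending PyTorch" notes. The module name `MyLinear` and its uniform initialization are illustrative assumptions, not what `nn.Linear` does internally. The module calls `LinearFunction.apply(...)` rather than creating a `LinearFunction()` instance. It also assumes a 2-D input of shape `(batch, in_features)`, whereas `nn.Linear` additionally accepts extra leading dimensions.

```python
import torch
from torch import nn


class LinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # Inputs go through save_for_backward; None is allowed for bias.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output = output + bias  # broadcasts over the batch dimension
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # Only compute the gradients autograd actually needs.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias


class MyLinear(nn.Module):
    """Hypothetical nn.Linear look-alike built on LinearFunction."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        # Illustrative init only; nn.Linear uses a different scheme.
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        # Use the static .apply, never LinearFunction().
        return LinearFunction.apply(x, self.weight, self.bias)
```

Checking `ctx.needs_input_grad` keeps the backward pass from computing gradients nobody asked for, for example the input gradient when the input does not require grad.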