I would like to slightly modify the nn.Linear module and create my own module. Would it be better to extend nn.Module or to extend nn.Linear and only implement the methods that require changes?

Would you like to implement it in the C++ backend or as the first step just in Python?

In the latter case I would just write a custom module.

Just in Python for the first step. Since I am only modifying the forward function, I thought it might be better to extend nn.Linear rather than extend nn.Module and have to copy the backprop operations. Is this a good approach?

Code reuse is always a valid and suggested approach in my humble opinion. If you think most of the expected behavior is already present in nn.Linear, it makes sense to extend it and write your modifications as a new class.

I’m completely on your side here; I just thought the linear implementation would call into some ATen methods for performance reasons, which doesn’t seem to be the case here.
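For what it's worth, a minimal sketch of the approach discussed above: subclass nn.Linear, reuse its parameters, and only override `forward`. The class name, the scaling modification, and the `scale` argument are hypothetical, just to illustrate the pattern.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledLinear(nn.Linear):
    """Hypothetical example: an nn.Linear whose output is scaled by a constant.
    Only forward() is overridden; weight, bias, and initialization are inherited."""
    def __init__(self, in_features, out_features, scale=2.0, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        self.scale = scale

    def forward(self, input):
        # Reuse the inherited weight and bias; only the forward pass changes.
        return self.scale * F.linear(input, self.weight, self.bias)

layer = ScaledLinear(4, 3)
out = layer(torch.randn(2, 4))
```

Since only PyTorch ops are used in `forward`, autograd still handles the backward pass automatically.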

I am new to PyTorch and I also need to create a custom non-linear activation function. One of my first questions is whether I should use only torch operators to create my function, or whether I can use modules like scipy or numpy in the forward and backward methods.

If you stick to PyTorch methods, Autograd will take care of your backward function, so that you won’t have to implement it yourself.

However, if you are using numpy methods or generally leave PyTorch, you would have to write the `backward` function yourself.
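A rough sketch of that second case: a `torch.autograd.Function` whose forward leaves PyTorch for numpy, so the backward must be supplied by hand. The function name and the choice of `exp` are illustrative assumptions.

```python
import numpy as np
import torch

class NumpyExp(torch.autograd.Function):
    """Hypothetical example: forward computed in numpy, so autograd cannot
    trace it and backward must be written manually."""
    @staticmethod
    def forward(ctx, input):
        # detach() before numpy(): the tensor leaves the autograd graph here.
        result = torch.from_numpy(np.exp(input.detach().numpy()))
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # d/dx exp(x) = exp(x), so reuse the saved forward result.
        result, = ctx.saved_tensors
        return grad_output * result

x = torch.randn(3, requires_grad=True, dtype=torch.double)
y = NumpyExp.apply(x)
```

Using double precision lets you verify the handwritten backward with `torch.autograd.gradcheck`.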

Thanks for your answer. This helps a lot.

Just for reference for anyone interested, I found these examples of implementing forward and backward methods with scipy and numpy. I have not tried them, but based on ptrblck’s answer I will definitely spend time trying them out.

https://seba-1511.github.io/tutorials/advanced/numpy_extensions_tutorial.html

Update: I am using PyTorch version 1.0.0. The numpy code example ran for me after changing the line `numpy_input = input.numpy()` to `numpy_input = input.detach().numpy()`.
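To illustrate why that change is needed: calling `.numpy()` directly on a tensor that requires grad raises a RuntimeError, because numpy cannot track the autograd graph; `detach()` returns a view that shares storage but is excluded from grad tracking.

```python
import torch

x = torch.ones(3, requires_grad=True)
# x.numpy() would raise a RuntimeError here, since x is part of the
# autograd graph; detach() first to get a plain, grad-free view.
arr = x.detach().numpy()
```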

Thanks again.