How to implement a scikit-learn estimator in PyTorch

I have developed an estimator in scikit-learn, but because of performance issues (both speed and memory usage) I am thinking of reimplementing it to run on a GPU.

One way I can think of is to implement the estimator in PyTorch (so it can run on the GPU) and then use Google Colab to take advantage of their cloud GPUs and memory capacity.

What would be the best way to implement, in PyTorch, an estimator that is already scikit-learn compatible?

Any pointers in the right direction would be much appreciated. Many thanks in advance.

I’m not sure how your estimator works, but you could probably just port the code to PyTorch.
Many NumPy operations are also implemented in PyTorch, so you can often just change the namespace (e.g. np.exp to torch.exp) and run the PyTorch operation instead.
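As a small illustration of switching the namespace (a generic softmax, not code from the estimator in this thread), here is a minimal sketch of the same function in NumPy and in PyTorch. The calls are mostly the same; the main differences are the reduction arguments (dim/keepdim instead of axis/keepdims) and the fact that Tensor.max(dim=...) returns a (values, indices) pair, which is why amax is used here.

```python
import numpy as np
import torch


def softmax_np(x):
    # NumPy version: subtract the row max for numerical stability, then normalize
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def softmax_torch(x):
    # PyTorch version: same calls under the torch namespace;
    # reductions take dim/keepdim, and amax returns only the values
    e = torch.exp(x - x.amax(dim=-1, keepdim=True))
    return e / e.sum(dim=-1, keepdim=True)
```

Calling softmax_torch(torch.from_numpy(x)).numpy() should give the same result as softmax_np(x), up to floating-point rounding.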

If you run into methods that are missing, let us know, as there might be some workarounds.

Thanks @ptrblck. The estimator I have developed is a classifier. You can see it here if you’re interested:

Since, as you said, a large number of NumPy operations are implemented in PyTorch, I don’t think I will have issues porting the code, as it is almost entirely written with NumPy arrays.
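When most of the code already works on NumPy arrays, one common pattern is to convert at the boundaries: turn the inputs into tensors on the GPU, do the heavy computation in PyTorch, and convert the result back with .cpu().numpy(). The helper pairwise_sq_dists below is hypothetical and only shows that round trip. torch.cuda.is_available() selects the GPU when one exists (e.g. on a Colab GPU runtime) and falls back to the CPU otherwise. float32 is an assumption here, since GPUs are usually much faster with it; keep float64 if the estimator needs the precision.

```python
import numpy as np
import torch

# Use the GPU if one is available (e.g. a Colab GPU runtime), otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def pairwise_sq_dists(a_np, b_np):
    """Hypothetical helper: squared Euclidean distances between the rows of two NumPy arrays."""
    # NumPy -> tensor on the chosen device (float32 assumed; keep float64 if precision matters)
    a = torch.as_tensor(a_np, dtype=torch.float32, device=device)
    b = torch.as_tensor(b_np, dtype=torch.float32, device=device)
    # The heavy lifting happens in PyTorch, on the GPU when available
    d = torch.cdist(a, b) ** 2
    # Tensor -> NumPy (via host memory) so the surrounding NumPy code stays unchanged
    return d.cpu().numpy()
```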

However, I am still not clear on how to carry the inheritance from scikit-learn’s BaseEstimator and ClassifierMixin (which make the estimator scikit-learn compatible, alongside its own .fit(), .predict(), etc. methods) over into PyTorch. Any clues, directions, or better yet, examples would be much appreciated. Many thanks in advance.
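It may not be necessary to inherit anything "into" PyTorch at all. A minimal sketch, assuming the estimator can keep its scikit-learn class and simply use torch tensors inside fit and predict: TorchNearestCentroid below is a hypothetical nearest-centroid classifier (not the estimator from this thread) that inherits from ClassifierMixin and BaseEstimator as usual, converts its inputs with torch.as_tensor, and returns NumPy arrays so the rest of scikit-learn keeps working. The mixin is listed before BaseEstimator, as the scikit-learn developer guide recommends.

```python
import numpy as np
import torch
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y


class TorchNearestCentroid(ClassifierMixin, BaseEstimator):
    """Hypothetical example: a nearest-centroid classifier whose maths runs in PyTorch."""

    def __init__(self, device="cpu"):
        # scikit-learn convention: __init__ only stores hyperparameters, unchanged
        self.device = device

    def fit(self, X, y):
        X, y = check_X_y(X, y)
        # Map arbitrary labels to 0..n_classes-1 and remember the original labels
        self.classes_, y_idx = np.unique(y, return_inverse=True)
        Xt = torch.as_tensor(X, dtype=torch.float32, device=self.device)
        yt = torch.as_tensor(y_idx, device=self.device)
        # One centroid (mean feature vector) per class, computed on the chosen device
        self.centroids_ = torch.stack(
            [Xt[yt == k].mean(dim=0) for k in range(len(self.classes_))]
        )
        self.n_features_in_ = X.shape[1]
        return self

    def predict(self, X):
        check_is_fitted(self)
        X = check_array(X)
        Xt = torch.as_tensor(X, dtype=torch.float32, device=self.device)
        # Distance from every sample to every centroid, then pick the closest one
        idx = torch.cdist(Xt, self.centroids_).argmin(dim=1)
        # Back to NumPy and to the original labels
        return self.classes_[idx.cpu().numpy()]
```

Because __init__ only stores its arguments and fitted attributes end in an underscore, scikit-learn utilities such as clone, cross_val_score or GridSearchCV should work with it as with any other estimator, and ClassifierMixin supplies score. On a GPU runtime you would pass device="cuda".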

I’m not sure what the best way would be, but you might want to take a look at skorch, which implements a sklearn-compatible API for PyTorch models, so I assume they have solved these issues. :wink:

Many thanks for your input, @ptrblck. I had a look at skorch and, if I am not mistaken, it implements sklearn-compatible APIs for “PyTorch models” and not the other way round, i.e. implementing “sklearn models” in PyTorch, which is what I am after.

I guess I will have to google around more to see what I can find. Some folks over at Reddit have suggested a few ideas that might work too.