Using classifiers in scikit-learn with PyTorch

How can we use machine learning classifiers like SVM, Decision Tree, Random Forest, etc., in PyTorch? Is there a library available for this?

You could apply these techniques using scikit-learn.
They work on NumPy arrays, so if you would like to train an SVM on top of a CNN, you could just get the NumPy array from the output tensor:

output = pytorch_model(x)
output = output.detach().cpu().numpy()  # detach from the autograd graph and move to the CPU first
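As a minimal sketch of this workflow, assuming a small, untrained CNN and random toy images as stand-ins for a real model and dataset (`feature_extractor`, `x` and `labels` are hypothetical names), the features can be extracted under `torch.no_grad()` and passed to scikit-learn's `SVC`:

```python
import torch
import torch.nn as nn
from sklearn.svm import SVC

torch.manual_seed(0)

# Hypothetical CNN used only as a feature extractor (stand-in for a trained model).
feature_extractor = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(4),
    nn.Flatten(),  # -> [batch, 8 * 4 * 4] = [batch, 128]
)
feature_extractor.eval()

# Toy data: 100 single-channel 28x28 "images"; class-1 images are shifted by +1.
y = torch.randint(0, 2, (100,))
x = torch.randn(100, 1, 28, 28) + y.view(-1, 1, 1, 1).float()
labels = y.numpy()

# No gradients are needed here, since scikit-learn trains the SVM.
with torch.no_grad():
    features = feature_extractor(x).cpu().numpy()

svm = SVC(kernel="rbf")
svm.fit(features, labels)
predictions = svm.predict(features)
```

Since scikit-learn only sees NumPy arrays, the same pattern works for `DecisionTreeClassifier` or `RandomForestClassifier` by swapping the estimator.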

Yes, scikit-learn is an option, but it does not support GPUs. Is there any way to use those classifiers on a GPU?

I think XGBoost has an option to build the trees on the GPU, but I'm not sure whether many of these methods would get a significant performance boost from a GPU.
Sebastian Raschka (@rasbt), the author of a nice book on this topic, will most definitely know more. :wink:

Scikit-learn currently doesn't have GPU support and is not planning to add it in the foreseeable future. The reason is that there is no benefit to using a GPU for (most of) the algorithms it implements. I suspect it would even make things slower compared to efficient C++ libraries like LIBLINEAR and LIBSVM. For distributed parallelism (e.g., distributing the computation across multiple machines), there's dask-ml, which has an API similar to sklearn's.


Sorry for the late reply. Thanks for letting me know about dask-ml.

We are using LDA (linear discriminant analysis) in the network's forward pass and need to keep the gradients, but the scikit-learn library does not work with PyTorch tensors that require gradients. Is there any way to use this library in our code?

I don't know how LDA is implemented in scikit-learn, but you could try to reimplement it using PyTorch operations and tensors, so that Autograd can record the computation graph and compute the gradients automatically.
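A minimal sketch of such a reimplementation, assuming LDA is used as a classifier on features produced by the network: the class means, class priors and shared (pooled) covariance are computed with plain tensor operations, so the linear discriminant scores stay differentiable. The helper names `lda_fit`/`lda_scores`, the small ridge term `eps` and the toy `encoder` are illustrative assumptions, not scikit-learn's API:

```python
import torch
import torch.nn.functional as F

def lda_fit(features, labels, num_classes, eps=1e-4):
    """Differentiable LDA 'fit': class means, priors and pooled within-class covariance."""
    n, d = features.shape
    one_hot = F.one_hot(labels, num_classes).to(features.dtype)  # [n, K]
    counts = one_hot.sum(dim=0)                                  # [K]
    priors = counts / n                                          # [K]
    means = (one_hot.T @ features) / counts.unsqueeze(1)         # [K, d]
    centered = features - means[labels]                          # subtract each sample's class mean
    cov = centered.T @ centered / n                              # shared covariance, [d, d]
    # Small ridge term for numerical stability (an assumption, not part of plain LDA).
    cov = cov + eps * torch.eye(d, dtype=features.dtype, device=features.device)
    return means, cov, priors

def lda_scores(x, means, cov, priors):
    """delta_k(x) = x^T S^-1 mu_k - 0.5 * mu_k^T S^-1 mu_k + log(pi_k)."""
    coef = torch.linalg.solve(cov, means.T).T                    # row k = (S^-1 mu_k)^T
    intercept = -0.5 * (means * coef).sum(dim=1) + torch.log(priors)
    return x @ coef.T + intercept                                # [batch, K]

torch.manual_seed(0)
encoder = torch.nn.Linear(10, 3)  # hypothetical layer that should receive gradients
labels = torch.arange(60) % 3
inputs = torch.randn(60, 10) + 3.0 * labels.unsqueeze(1)  # toy, roughly separable classes

features = encoder(inputs)
means, cov, priors = lda_fit(features, labels, num_classes=3)
scores = lda_scores(features, means, cov, priors)
loss = F.cross_entropy(scores, labels)
loss.backward()  # gradients flow through the LDA statistics into the encoder
```

Here the LDA statistics are recomputed from each batch, so gradients also flow through the means and covariance. If LDA is used for dimensionality reduction (`transform`) instead, the projection additionally needs an eigendecomposition; `torch.linalg.eigh` supports autograd, though its gradients can be unstable when eigenvalues are (nearly) repeated.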
If that's not possible, e.g. because a specific NumPy operation is not implemented in PyTorch, you would have to derive the backward pass manually and implement a custom autograd.Function as described in this tutorial.
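A minimal sketch of such a custom `autograd.Function`, using softplus, log(1 + e^x), computed in NumPy as a stand-in for an operation PyTorch lacks (PyTorch actually has `softplus`, which makes the result easy to check). `NumpySoftplus` is a hypothetical name; the backward uses the hand-derived derivative d/dx log(1 + e^x) = sigmoid(x):

```python
import numpy as np
import torch

class NumpySoftplus(torch.autograd.Function):
    """Softplus computed in NumPy, with a manually derived backward pass."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        x_np = x.detach().cpu().numpy()
        y_np = np.logaddexp(0.0, x_np)  # log(1 + exp(x)), numerically stable
        return torch.from_numpy(y_np).to(device=x.device, dtype=x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        x_np = x.detach().cpu().numpy()
        sigmoid_np = 0.5 * (1.0 + np.tanh(0.5 * x_np))  # stable sigmoid = d softplus / dx
        local_grad = torch.from_numpy(sigmoid_np).to(device=x.device, dtype=x.dtype)
        return grad_output * local_grad  # chain rule

x = torch.randn(5, dtype=torch.double, requires_grad=True)
y = NumpySoftplus.apply(x)
y.sum().backward()

# Compare the hand-written backward against finite differences.
ok = torch.autograd.gradcheck(
    NumpySoftplus.apply, (torch.randn(4, dtype=torch.double, requires_grad=True),)
)
```

`torch.autograd.gradcheck` should be run in double precision. Note that the NumPy round trip moves data to the CPU, so this will be slower than native PyTorch operations on a GPU.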