Support for one-vs-all classification in PyTorch

I have been stuck on a multi-class classification problem with a large number of classes (500), and a one-vs-rest classifier seems like a good way to handle it. Is there any built-in support for this in PyTorch itself?

All I could find is this post, which doesn’t have any answers either. I understand that scikit-learn offers this feature, but I would like to stick with PyTorch because of the many familiar customizations this framework lets me tweak. I’ve tried exploring skorch, but it is overly simplified, and its documentation isn’t detailed enough to allow the modifications my use case needs.
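For reference, here is a minimal sketch of how one-vs-rest could be expressed directly in PyTorch. It assumes a shared feature extractor followed by a single `nn.Linear` head with one logit per class, trained with `binary_cross_entropy_with_logits` on one-hot targets. Each output then acts as an independent "class k vs. rest" binary classifier. This is a shared-backbone variant, not the same as scikit-learn's `OneVsRestClassifier`, which fits fully separate estimators. The class `OneVsRestHead`, the helpers `ovr_loss`/`predict`, and all sizes are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

NUM_CLASSES = 500   # number of classes from the question
NUM_FEATURES = 64   # hypothetical input dimensionality


class OneVsRestHead(nn.Module):
    """Hypothetical model: shared backbone + one logit per class.

    Each logit is trained as an independent binary classifier
    ("this class" vs. "all other classes").
    """

    def __init__(self, num_features: int, num_classes: int, hidden: int = 128):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(num_features, hidden), nn.ReLU())
        self.head = nn.Linear(hidden, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.backbone(x))  # shape (N, num_classes)


def ovr_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """One-vs-rest loss: per-class binary cross-entropy on one-hot targets.

    targets holds integer class indices with shape (N,). For each sample,
    the true class is the positive label and every other class is negative.
    """
    one_hot = F.one_hot(targets, num_classes=logits.size(1)).float()
    return F.binary_cross_entropy_with_logits(logits, one_hot)


def predict(logits: torch.Tensor) -> torch.Tensor:
    """Pick the class whose binary classifier is most confident."""
    return torch.sigmoid(logits).argmax(dim=1)


# Tiny training loop on random data, just to show the pieces fit together.
torch.manual_seed(0)
model = OneVsRestHead(NUM_FEATURES, NUM_CLASSES)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(32, NUM_FEATURES)
y = torch.randint(0, NUM_CLASSES, (32,))

losses = []
for _ in range(20):
    optimizer.zero_grad()
    loss = ovr_loss(model(x), y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

with torch.no_grad():
    preds = predict(model(x))
```

Because `sigmoid` is monotonic, taking `argmax` over the raw logits gives the same prediction in practice. With 500 classes, each binary problem sees far more negatives than positives; `binary_cross_entropy_with_logits` accepts a `pos_weight` tensor of shape `(num_classes,)` that can be used to rebalance this if needed.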

@ptrblck, can you share any pointers here? Thanks in advance!