Maybe the answer to this Stack Overflow question is helpful:
In mathematical terms, what exactly do you want to do? If you describe the math, it might be easier for people to help you than if they have to work out how to port a TF function.
If you want to do multi-label classification, so do I, but I haven't yet figured out how to do it in PyTorch, so I'm also interested in your question.
Best,
Ajay