Implementation of the alpha-loss function

I came across this parametric loss function here: "A Class of Parameterized Loss Functions for Classification: Optimization Tradeoffs and Robustness Characteristics" (Lalitha Sankar).

I was wondering if there is a simple way to implement this alpha-loss in PyTorch for multi-class classification. I found an implementation on GitHub (link below), but it seems to handle only binary classification (?).

Implementation: SankarLab/AlphaLoss-TransIT-Code-MNIST-CIFAR (main branch) on GitHub
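If I'm reading the paper correctly, alpha-loss depends only on the probability P(y) that the model assigns to the true label: ℓ^α(y, P) = α/(α−1) · (1 − P(y)^(1−1/α)) for α in (0, ∞), with log-loss as the α → 1 limit and 1 − P(y) as the α → ∞ limit. Assuming that the multi-class version simply applies this to the softmax probability of the true class, a minimal sketch could look like the following. The function name `alpha_loss` is hypothetical and is not taken from the SankarLab repo.

```python
import torch
import torch.nn.functional as F


def alpha_loss(logits: torch.Tensor, targets: torch.Tensor, alpha: float) -> torch.Tensor:
    """Hypothetical multi-class alpha-loss (mean over the batch).

    Assumes loss = alpha/(alpha-1) * (1 - p_y ** (1 - 1/alpha)), where p_y is the
    softmax probability of the true class.
    logits: (N, C) raw scores; targets: (N,) integer class indices.
    """
    if alpha <= 0:
        raise ValueError("alpha must be positive")
    # log p_y for each sample, computed stably via log_softmax
    log_p = F.log_softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
    if alpha == 1.0:
        # alpha -> 1 limit: ordinary cross-entropy (log-loss)
        loss = -log_p
    elif alpha == float("inf"):
        # alpha -> infinity limit: 1 - p_y
        loss = 1.0 - log_p.exp()
    else:
        # p_y ** k = exp(k * log p_y); use expm1 so 1 - exp(x) stays accurate near alpha = 1
        k = 1.0 - 1.0 / alpha
        loss = alpha / (alpha - 1.0) * (-torch.expm1(k * log_p))
    return loss.mean()
```

It can be used as a drop-in replacement for `F.cross_entropy(logits, targets)`, e.g. `loss = alpha_loss(model(x), y, alpha=0.8)`. With `alpha=1.0` it reduces to cross-entropy, which is a convenient sanity check. The α = 1 and α = ∞ cases are handled separately because the general formula becomes 0/0 or ∞/∞ there.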