Implementation of Squared Earth Mover's Distance loss function for ordinal scale

Hi everyone, I recently came across the paper “Squared Earth Mover's Distance-based Loss for Training Deep Neural Networks” ([1611.05916] Squared Earth Mover's Distance-based Loss for Training Deep Neural Networks). I want to use the squared EMD loss function for an ordinal classification problem.

However, I could not find a PyTorch implementation of it. There is a TensorFlow implementation, which is as follows:

```python
from typing import Callable


def earth_mover_distance(
        **kwargs
) -> Callable:
    """
    Wrapper for earth_mover distance for unified interface with self-guided earth mover distance loss.
    """
    import tensorflow as tf

    def _earth_mover_distance(
            y_true: tf.Tensor,
            y_pred: tf.Tensor
    ) -> tf.Tensor:
        # Squared difference of the cumulative distributions (CDFs), averaged over the class axis
        return tf.reduce_mean(tf.square(tf.cumsum(y_true, axis=-1) - tf.cumsum(y_pred, axis=-1)), axis=-1)

    return _earth_mover_distance
```
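For each sample, the snippet computes the mean over the C classes of the squared difference between the two cumulative distributions; averaging instead of summing over classes only rescales the loss by 1/C. The small NumPy transcription below (`squared_emd` is a hypothetical name, not from the paper or TensorFlow) shows why this suits ordinal labels. With the true label at class 0, putting all predicted mass on class k costs k/C, so predictions farther from the true class are penalized more. Cross-entropy with a one-hot target depends only on the probability assigned to the true class, so it cannot tell these cases apart.

```python
import numpy as np


def squared_emd(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Per-sample squared EMD, mirroring the TensorFlow snippet.

    Both inputs have shape (..., C) and hold probability distributions
    over C ordered classes (each row sums to 1).
    """
    cdf_true = np.cumsum(y_true, axis=-1)
    cdf_pred = np.cumsum(y_pred, axis=-1)
    # Mean over the class axis, like tf.reduce_mean(..., axis=-1)
    return np.mean(np.square(cdf_true - cdf_pred), axis=-1)


# True label is class 0 out of 5 ordered classes (one-hot).
target = np.array([1.0, 0.0, 0.0, 0.0, 0.0])

# Put all predicted mass on class k; the loss equals k / 5.
for k in range(5):
    pred = np.eye(5)[k]
    print(k, squared_emd(target, pred))
```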

Can someone please help me implement this in PyTorch?

You could try to translate this code by using the torch namespace.
E.g., this might work (untested):

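```python
import torch

def earth_mover_distance(y_true, y_pred):
    # Returns one value per sample; reduce it (e.g. .mean()) before calling .backward()
    return torch.mean(torch.square(torch.cumsum(y_true, dim=-1) - torch.cumsum(y_pred, dim=-1)), dim=-1)
```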
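For a typical classification setup, a slightly fuller sketch could wrap the same formula in an `nn.Module`. It assumes the model outputs raw logits of shape (N, C) and that targets are integer class indices of shape (N,) with the classes in ordinal order. The class name `SquaredEMDLoss` and its `reduction` argument are hypothetical, chosen to resemble built-in PyTorch losses. Unlike the TensorFlow snippet, which expects `y_pred` to already be probabilities, this version applies softmax itself and one-hot encodes the targets.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SquaredEMDLoss(nn.Module):
    """Squared EMD loss for ordinal classification (hypothetical helper).

    Expects raw logits of shape (N, C) and integer targets of shape (N,),
    where the C classes are ordered (0 < 1 < ... < C-1).
    """

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        if reduction not in ("mean", "sum", "none"):
            raise ValueError(f"invalid reduction: {reduction}")
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        num_classes = logits.size(-1)
        # Turn logits into a probability distribution over the ordered classes.
        probs = F.softmax(logits, dim=-1)
        # One-hot encode the integer labels so they are distributions too.
        target_dist = F.one_hot(target, num_classes).to(probs.dtype)
        # Difference of the cumulative distributions along the class axis.
        cdf_diff = torch.cumsum(probs, dim=-1) - torch.cumsum(target_dist, dim=-1)
        # Per-sample loss: mean of squared CDF differences, as in the TF snippet.
        loss = cdf_diff.pow(2).mean(dim=-1)
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss


# Usage: random logits for 8 samples and 5 ordered classes.
criterion = SquaredEMDLoss()
logits = torch.randn(8, 5, requires_grad=True)
target = torch.randint(0, 5, (8,))
loss = criterion(logits, target)
loss.backward()
```

Because the per-sample values are reduced inside the module, the result is a scalar and `.backward()` can be called on it directly.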