Implementation of squared Earth Mover's Distance loss function for an ordinal scale

Hi everyone, I recently came across the paper "Squared Earth Mover's Distance-based Loss for Training Deep Neural Networks" ([1611.05916] Squared Earth Mover's Distance-based Loss for Training Deep Neural Networks). I want to use the squared EMD loss function for an ordinal classification problem.
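For reference, if I understand the paper correctly, for K ordered classes the loss compares the cumulative distributions of the predicted probabilities p and the target t:

EMD^2(p, t) = sum_k (CDF_p(k) - CDF_t(k))^2, where CDF_p(k) = p_1 + ... + p_k

i.e. cumulative sums are taken along the class axis, their difference is squared, and the result is summed (or, in the implementation below, averaged) over the classes.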

However, I could not find a PyTorch implementation of it. There is a TensorFlow implementation, which is as follows:

def earth_mover_distance(
        **kwargs
) -> Callable:
    """
    Wrapper for earth_mover distance for unified interface with self-guided earth mover distance loss.
    """
    import tensorflow as tf
    from tensorflow.keras import backend as K

    def _earth_mover_distance(
            y_true: K.placeholder,
            y_pred: K.placeholder
    ) -> K.placeholder:
        return tf.reduce_mean(tf.square(tf.cumsum(y_true, axis=-1) - tf.cumsum(y_pred, axis=-1)), axis=-1)

    return _earth_mover_distance

Can someone please help me with the implementation in PyTorch?


You could try to translate this code using the torch namespace.
E.g. this might work (untested):

import torch

def earth_mover_distance(y_true, y_pred):
    # mean of the squared differences between the cumulative sums (CDFs) along the class axis
    return torch.mean(torch.square(torch.cumsum(y_true, dim=-1) - torch.cumsum(y_pred, dim=-1)), dim=-1)
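For the ordinal classification use case, a minimal (untested) usage sketch built around the earth_mover_distance above could look like this; the shapes, label values, and the softmax/one-hot encoding are just illustrative assumptions:

import torch
import torch.nn.functional as F

# hypothetical example: batch of 4 samples, 5 ordered classes
logits = torch.randn(4, 5, requires_grad=True)      # raw network outputs
y_pred = F.softmax(logits, dim=-1)                   # predicted class probabilities
targets = torch.tensor([0, 2, 4, 1])                 # ordinal class labels
y_true = F.one_hot(targets, num_classes=5).float()   # one-hot encoded targets

loss = earth_mover_distance(y_true, y_pred).mean()   # per-sample losses averaged over the batch
loss.backward()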

Hello sir, I tried the above code but I got the following error:

TypeError: cumsum() received an invalid combination of arguments - got (list, dim=int), but expected one of:
 * (Tensor input, int dim, *, torch.dtype dtype, Tensor out)
 * (Tensor input, name dim, *, torch.dtype dtype, Tensor out)

Please help me.

As the error message explains: pass a tensor to torch.cumsum instead of a list:

torch.cumsum([1., 2.], dim=0)
# TypeError: cumsum() received an invalid combination of arguments - got (list, dim=int), but expected one of:
#  * (Tensor input, int dim, *, torch.dtype dtype, Tensor out)
#  * (Tensor input, name dim, *, torch.dtype dtype, Tensor out)

torch.cumsum(torch.randn(10), dim=0) # works
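If y_true or y_pred is currently a Python list (or a NumPy array), converting it to a tensor first should resolve the error; a minimal sketch:

torch.cumsum(torch.tensor([1., 2.]), dim=0)  # tensor([1., 3.])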

Will try. Thank you @ptrblck!