Hi everyone, I recently came across the paper "Squared Earth Mover's Distance-based Loss for Training Deep Neural Networks" (arXiv:1611.05916). I want to use the squared EMD loss function for an ordinal classification problem.

However, I could not find a PyTorch implementation of it. There is a TensorFlow implementation, which is as follows:

```
from typing import Callable

import tensorflow as tf


def earth_mover_distance(**kwargs) -> Callable:
    """
    Wrapper for earth_mover distance for unified interface with self-guided earth mover distance loss.
    """
    def _earth_mover_distance(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        # Mean squared difference between the cumulative sums over the class axis.
        return tf.reduce_mean(
            tf.square(tf.cumsum(y_true, axis=-1) - tf.cumsum(y_pred, axis=-1)),
            axis=-1,
        )

    return _earth_mover_distance
```

Can someone please help me with the implementation in PyTorch?

ptrblck
January 15, 2021, 11:13am
#2
You could try to translate this code by using the `torch` namespace.
E.g. this might work (untested):

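```
import torch


def earth_mover_distance(y_true, y_pred):
    # Mean squared difference between the cumulative sums over the last (class) dimension.
    return torch.mean(
        torch.square(torch.cumsum(y_true, dim=-1) - torch.cumsum(y_pred, dim=-1)),
        dim=-1,
    )
```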
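A rough usage sketch, under a few assumptions that are not from the thread: the model outputs raw logits, the targets are integer class labels, and both are turned into per-class probability vectors before the loss is applied. The names `num_classes`, `logits`, and `labels` are hypothetical placeholders, and the random logits only stand in for a real model output.

```python
import torch
import torch.nn.functional as F


def earth_mover_distance(y_true, y_pred):
    # Mean squared difference between the cumulative sums over the class dimension.
    return torch.mean(
        torch.square(torch.cumsum(y_true, dim=-1) - torch.cumsum(y_pred, dim=-1)),
        dim=-1,
    )


num_classes = 5  # hypothetical number of ordinal classes
logits = torch.randn(4, num_classes, requires_grad=True)  # stand-in for model output
labels = torch.tensor([0, 2, 4, 1])  # hypothetical integer class labels

# Both inputs are assumed to be distributions over the classes.
y_pred = F.softmax(logits, dim=-1)
y_true = F.one_hot(labels, num_classes).to(y_pred.dtype)

per_sample = earth_mover_distance(y_true, y_pred)  # one value per sample, shape (4,)
loss = per_sample.mean()  # reduce to a scalar before calling backward
loss.backward()
```

The function returns one value per sample, so it has to be reduced (here with a mean over the batch) before `backward()` can be called without arguments.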

jyothi_sri
(jyothi sri)
October 17, 2022, 5:32pm
#3
Hello sir, I tried the above code but got the following error:
TypeError: cumsum() received an invalid combination of arguments - got (list, dim=int), but expected one of:

(Tensor input, int dim, *, torch.dtype dtype, Tensor out)
(Tensor input, name dim, *, torch.dtype dtype, Tensor out)
Please help me.

ptrblck
October 26, 2022, 6:08am
#4
As the error message explains: pass a `tensor` to `torch.cumsum` instead of a `list`:

```
torch.cumsum([1., 2.], dim=0)
# TypeError: cumsum() received an invalid combination of arguments - got (list, dim=int), but expected one of:
# * (Tensor input, int dim, *, torch.dtype dtype, Tensor out)
# * (Tensor input, name dim, *, torch.dtype dtype, Tensor out)
torch.cumsum(torch.randn(10), dim=0) # works
```
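If the data starts out as a Python list, one way to fix the call is to convert it first. This is a minimal sketch; the list `values` is made up for illustration, and `torch.as_tensor` is just one of several conversion options (`torch.tensor` would also work).

```python
import torch

values = [1.0, 2.0, 3.0]  # a plain Python list, as in the failing call

# Convert the list to a tensor before passing it to torch.cumsum.
tensor = torch.as_tensor(values, dtype=torch.float32)
running_sum = torch.cumsum(tensor, dim=0)
print(running_sum)  # tensor([1., 3., 6.])
```

The same conversion would apply to both `y_true` and `y_pred` before they are passed to the loss function.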

jyothi_sri
(jyothi sri)
November 3, 2022, 3:46am
#5
Will try. Thank you, @ptrblck.