Hi everyone, I recently came across the paper “Squared Earth Mover's Distance-based Loss for Training Deep Neural Networks” ([1611.05916] Squared Earth Mover's Distance-based Loss for Training Deep Neural Networks). I want to use the squared EMD loss for an ordinal classification problem.
However, I could not find a PyTorch implementation of it. There is a TensorFlow implementation, which is as follows:
```python
from typing import Callable


def earth_mover_distance(**kwargs) -> Callable:
    """
    Wrapper for earth_mover distance for unified interface with self-guided earth mover distance loss.
    """
    import tensorflow as tf

    def _earth_mover_distance(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        # Mean over classes of the squared difference between the two CDFs
        return tf.reduce_mean(
            tf.square(tf.cumsum(y_true, axis=-1) - tf.cumsum(y_pred, axis=-1)),
            axis=-1,
        )

    return _earth_mover_distance
```
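For reference, the snippet above computes, for each sample with target distribution $\mathbf{y}$ (typically one-hot) and predicted distribution $\hat{\mathbf{y}}$ over $C$ ordered classes, the mean of the squared differences between the two cumulative distributions. This restates what the code does; the normalization used in the paper itself may differ:

$$
\mathrm{EMD}^2(\mathbf{y}, \hat{\mathbf{y}}) = \frac{1}{C} \sum_{k=1}^{C} \left( \sum_{i=1}^{k} y_i - \sum_{i=1}^{k} \hat{y}_i \right)^2
$$

Because the CDFs are compared rather than the individual probabilities, mass placed on a class far from the true one shows up in more terms of the outer sum, so distant mistakes are penalized more than adjacent ones.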
Can someone please help me with a PyTorch implementation?
ptrblck
January 15, 2021, 11:13am
You could try to translate this code by using the `torch` namespace.
E.g. this might work (untested):
```python
import torch
def earth_mover_distance(y_true, y_pred):
    # Per-sample loss: mean over classes of the squared CDF difference
    return torch.mean(torch.square(torch.cumsum(y_true, dim=-1) - torch.cumsum(y_pred, dim=-1)), dim=-1)
```
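A minimal sketch of how this translation could be used as a training loss, assuming the model outputs raw logits of shape `(batch, num_classes)` and the ordinal labels are integer class indices (both are assumptions, as is the hypothetical name `squared_emd_loss`). It applies `softmax`, one-hot encodes the labels with `torch.nn.functional.one_hot`, and averages the per-sample values over the batch so that `backward()` can be called on a scalar:

```python
import torch
import torch.nn.functional as F


def squared_emd_loss(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: squared EMD between softmax(logits) and one-hot targets.

    logits: (batch, num_classes) raw scores (assumed model output).
    target: (batch,) integer class indices in [0, num_classes).
    Returns a scalar: the per-sample loss averaged over the batch.
    """
    num_classes = logits.size(-1)
    # Turn the scores into a probability distribution over the ordered classes
    y_pred = F.softmax(logits, dim=-1)
    # One-hot encode the integer labels and match the float dtype of y_pred
    y_true = F.one_hot(target, num_classes=num_classes).to(y_pred.dtype)
    # Difference of the cumulative distributions along the class dimension
    cdf_diff = torch.cumsum(y_true, dim=-1) - torch.cumsum(y_pred, dim=-1)
    # Mean over classes (as in the TF snippet), then mean over the batch
    per_sample = torch.mean(torch.square(cdf_diff), dim=-1)
    return per_sample.mean()


logits = torch.randn(4, 5, requires_grad=True)
target = torch.tensor([0, 2, 4, 1])
loss = squared_emd_loss(logits, target)
loss.backward()
print(loss.item(), logits.grad.shape)
```

With 5 classes and target class 0, putting all probability on class 1 gives a loss of 0.2, while putting it on class 4 gives 0.8, which is the ordinal behavior the paper is after.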
jyothi_sri
(jyothi sri)
October 17, 2022, 5:32pm
Hello sir, I tried the above code but got the following error:
```
TypeError: cumsum() received an invalid combination of arguments - got (list, dim=int), but expected one of:
 * (Tensor input, int dim, *, torch.dtype dtype, Tensor out)
 * (Tensor input, name dim, *, torch.dtype dtype, Tensor out)
```
Please help me.
As the error message explains: pass a tensor to `torch.cumsum` instead of a list:
```python
import torch

torch.cumsum([1., 2.], dim=0)
# TypeError: cumsum() received an invalid combination of arguments - got (list, dim=int), but expected one of:
#  * (Tensor input, int dim, *, torch.dtype dtype, Tensor out)
#  * (Tensor input, name dim, *, torch.dtype dtype, Tensor out)

torch.cumsum(torch.randn(10), dim=0)  # works
```
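To see the fix end to end, here is a minimal sketch that converts Python lists to tensors with `torch.as_tensor` before calling the `earth_mover_distance` translation from the earlier reply (redefined here so the snippet runs on its own). The list values are made-up examples: a one-hot target for class 1 of 3 and a predicted distribution.

```python
import torch


def earth_mover_distance(y_true, y_pred):
    # Same translation as in the earlier reply: per-sample mean of squared CDF differences
    return torch.mean(torch.square(torch.cumsum(y_true, dim=-1) - torch.cumsum(y_pred, dim=-1)), dim=-1)


# Plain Python lists (example values) must be converted to float tensors first
y_true_list = [[0.0, 1.0, 0.0]]
y_pred_list = [[0.2, 0.5, 0.3]]

y_true = torch.as_tensor(y_true_list, dtype=torch.float32)
y_pred = torch.as_tensor(y_pred_list, dtype=torch.float32)

loss = earth_mover_distance(y_true, y_pred)  # one value per sample, shape (1,)
print(loss)
```

Here the CDFs are `[0, 1, 1]` and `[0.2, 0.7, 1.0]`, so the squared differences are 0.04, 0.09 and 0, and their mean is about 0.0433.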
jyothi_sri
(jyothi sri)
November 3, 2022, 3:46am
Will try. Thank you @ptrblck!