Is there any implementation of EMD in pytorch?

I’m looking for a differentiable EMD (earth mover’s distance) function to measure distances on a network’s latent representation.

I’m not entirely sure I understand your use case enough to be certain that it’s what you’re looking for, but I can offer a notebook implementing entropy-regularized Wasserstein distances.

Best regards


I have seen this work before.
My question is: let’s say my network output is a vector of 10 elements and my ground truth has the same shape. I want my cost function to measure the distance between them using EMD.

What are the probability measures you want to measure the distance between, then?
Usually, you have something like a probability distribution over 10 elements, just as you (conceptually) would for classification, and then take the KL divergence / cross-entropy to the target distribution (peaked at the target). That can be straightforwardly replaced by the Wasserstein distance, as Frogner et al. do, and (in the regularized case) can be done with my implementation.
Note that the whole thing is somewhat numerically sensitive.
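To make the regularized case concrete, here is a minimal log-domain Sinkhorn sketch. It is not the code from the notebook mentioned above: the function name `sinkhorn_loss` and the parameters `reg` and `n_iters` are made up for illustration, and it assumes each row of `pred` and `target` is a non-negative histogram summing to 1.

```python
import torch

def sinkhorn_loss(pred, target, cost, reg=0.1, n_iters=200):
    """Entropy-regularized transport cost between batches of histograms.

    pred, target: (batch, n) tensors, rows non-negative and summing to 1.
    cost:         (n, n) ground-cost matrix.
    Hypothetical helper for illustration; names are not from any library.
    """
    tiny = 1e-30                          # avoids log(0) for empty bins
    log_a = torch.log(pred + tiny)        # (B, n)
    log_b = torch.log(target + tiny)      # (B, n)
    C = cost.unsqueeze(0)                 # (1, n, n)
    f = torch.zeros_like(pred)            # dual potential for rows
    g = torch.zeros_like(target)          # dual potential for columns
    for _ in range(n_iters):
        # Alternately enforce the row and column marginals in log space.
        f = -reg * torch.logsumexp(
            (g.unsqueeze(1) - C) / reg + log_b.unsqueeze(1), dim=2)
        g = -reg * torch.logsumexp(
            (f.unsqueeze(2) - C) / reg + log_a.unsqueeze(2), dim=1)
    # Transport plan: P_ij = a_i * b_j * exp((f_i + g_j - C_ij) / reg)
    log_P = ((f.unsqueeze(2) + g.unsqueeze(1) - C) / reg
             + log_a.unsqueeze(2) + log_b.unsqueeze(1))
    return (log_P.exp() * C).sum(dim=(1, 2))   # one value per batch item
```

Working in log space with `logsumexp` is one common way to reduce the overflow/underflow problems that show up for small `reg`; smaller `reg` gets closer to the true EMD but needs more iterations.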

Best regards


The data are histograms, so I want to use EMD to measure the distance between them. I didn’t understand your KL / CE part.
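For one-dimensional histograms with equally spaced bins, the EMD has a closed form: it is the sum of absolute differences between the two cumulative distributions. A minimal sketch, assuming unit spacing between adjacent bins and non-negative inputs with non-zero total mass (the name `emd_1d` is made up for illustration):

```python
import torch

def emd_1d(pred, target):
    """EMD between batches of 1D histograms on equally spaced bins.

    pred, target: (batch, n_bins), non-negative, non-zero total mass.
    Assumes unit distance between adjacent bins.
    """
    # Normalize so both histograms carry the same total mass.
    pred = pred / pred.sum(dim=-1, keepdim=True)
    target = target / target.sum(dim=-1, keepdim=True)
    # The difference of the CDFs is the mass that must cross each bin boundary.
    cdf_diff = torch.cumsum(pred - target, dim=-1)
    return cdf_diff.abs().sum(dim=-1)

pred = torch.tensor([[1.0, 0.0, 0.0]], requires_grad=True)
target = torch.tensor([[0.0, 0.0, 1.0]])
loss = emd_1d(pred, target)   # all mass moves two bins
loss.sum().backward()         # differentiable, so usable as a loss
```

This only covers the 1D case; for bins without a natural ordering or with a general ground cost, a regularized solver is needed instead.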

You can try this repository. I have been searching for this too, so I spent some time trying to update and generalize a few implementations I’d seen on GitHub. If you have a fix for the overflow issue too, I would be grateful :wink:

Hey! I came across this while searching for PyTorch EMD implementations, and I was wondering if this would work with input tensors with sizes of around (1, 16k, 3), so basically, batch size of 1, and 16k points that are represented as x, y, z. If not, would you happen to have any suggestions on how to implement some sort of EMD approximation myself using PyTorch? I’m not particularly concerned about speed yet, just some sort of implementation that can work.
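One approximation that is easy to write in plain PyTorch is the sliced Wasserstein distance. To be clear, this is a different quantity from the exact EMD, not an implementation of any of the repositories mentioned in this thread: it projects both clouds onto random directions and uses the fact that 1D optimal transport between equal-size point sets is solved by sorting. A minimal sketch, assuming both clouds have the same number of points (the name `sliced_wasserstein` and its parameters are made up for illustration):

```python
import torch

def sliced_wasserstein(x, y, n_projections=128, p=2):
    """Sliced Wasserstein distance between point clouds.

    x, y: (batch, n_points, dim) with the same n_points.
    Approximates, but is not equal to, the EMD.
    """
    dim = x.shape[-1]
    # Random unit directions, one per column.
    dirs = torch.randn(dim, n_projections, device=x.device, dtype=x.dtype)
    dirs = dirs / dirs.norm(dim=0, keepdim=True)
    # Project every point onto every direction: (batch, n_points, n_projections)
    x_proj = x @ dirs
    y_proj = y @ dirs
    # In 1D, the optimal matching pairs the sorted points.
    x_sorted, _ = torch.sort(x_proj, dim=1)
    y_sorted, _ = torch.sort(y_proj, dim=1)
    cost = (x_sorted - y_sorted).abs().pow(p).mean(dim=(1, 2))
    return cost.pow(1.0 / p)
```

Memory is O(n_points × n_projections) rather than the O(n_points²) of a full cost matrix, so a (1, 16k, 3) input is not a problem. Note that the gradient of the final root is undefined when the two clouds coincide exactly; dropping the last `pow` avoids that if it matters.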

None of the suggested methods have worked for me yet. If you come across something better, please post it back here.

Hi @tom,

Many thanks for making this implementation available to the community, I appreciate it a lot. I have almost figured out how to use it, but I have a problem, please forgive me if this is a very naive question.

After instantiating a criterion based on your WassersteinLossStab() class, if I try to compute the loss with a batch size of 1 it seems to work pretty consistently, but if I change the batch size to >1, it crashes:

import torch

# WassersteinLossStab and M (a NumPy array, passed in as the cost matrix)
# are assumed to be defined beforehand.
batch_size = 1  # setting this to 2 breaks the loss
n_samples = 10

preds = torch.FloatTensor(batch_size, n_samples).uniform_()
preds = preds.float()

targets = torch.randn(batch_size, n_samples) > 0
targets = targets.float()

criterion = WassersteinLossStab(torch.from_numpy(M), lam=0.1)

criterion(preds, targets)

Do you have any clue what I am doing wrong here? Thanks!

Please try my implementation here for 3D point cloud research:

I just made a PyTorch wrapper for Haoqiang Fan’s implementation for the paper “A Point Set Generation Network for 3D Object Reconstruction from a Single Image”. Please cite this paper if you use the code.



This may help you; it contains solvers based on QPTH and OpenCV.