Wasserstein loss layer/criterion


Yes, conv double backward was merged into master yesterday:
https://github.com/pytorch/pytorch/pull/1832 based on

I think most of the other functions are there as well.
Thanks to those involved for all the hard work!

Best regards



Is BatchNorm supported? I am getting an error that BatchNormBackward is not supported.
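For what it is worth, current PyTorch releases support double backward through both `Conv2d` and `BatchNorm2d` (in training mode), which is what gradient-penalty terms need. Here is a minimal check with a toy critic. The layer sizes are arbitrary assumptions and have nothing to do with the PR above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy critic with a convolution and a BatchNorm layer (training mode by default)
critic = nn.Sequential(
    nn.Conv2d(1, 4, 3),          # 1 x 9 x 9 -> 4 x 7 x 7
    nn.BatchNorm2d(4),
    nn.LeakyReLU(0.2),
    nn.Flatten(),
    nn.Linear(4 * 7 * 7, 1),
)

x = torch.randn(8, 1, 9, 9, requires_grad=True)
out = critic(x).sum()
# First backward: gradient w.r.t. the input, keeping the graph for a second pass
(grad_x,) = torch.autograd.grad(out, x, create_graph=True)
# Second backward: differentiate a function of grad_x w.r.t. the parameters
penalty = grad_x.flatten(1).norm(dim=1).pow(2).mean()
penalty.backward()

conv_grad = critic[0].weight.grad  # reached through the conv double backward
bn_grad = critic[1].weight.grad    # reached through the BatchNorm double backward
```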


Hi Tom,

I have been trying out your improvements to WGAN (thanks!) - for lipschitz_constraint == 1 in https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/Improved_Training_of_Wasserstein_GAN.ipynb

Do you think taking the mean instead of the sum in this line, `dist = ((vinput-vinput2)**2).sum(1)**0.5`, would work better for high-dimensional inputs? The Euclidean norm (especially for images) quickly exceeds the order of magnitude of the discriminator outputs.

My initial experiments suggest that the mean works for images but the sum does not.
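To make the scaling concrete, here is a minimal sketch with a hypothetical helper `sample_dist` standing in for the notebook's `dist` line (`vinput`, `vinput2` there, `x`, `y` here). Averaging instead of summing under the square root divides every per-sample distance by the same constant sqrt(D), where D is the number of features per sample:

```python
import torch

def sample_dist(x, y, reduction="sum"):
    """Per-sample distance between two batches, each sample flattened to D features.

    reduction="sum":  plain Euclidean norm ||x - y||
    reduction="mean": sqrt(mean of squared differences) = ||x - y|| / sqrt(D)
    """
    diff2 = (x - y).flatten(1) ** 2
    agg = diff2.sum(1) if reduction == "sum" else diff2.mean(1)
    return agg ** 0.5

torch.manual_seed(0)
x = torch.rand(8, 3, 64, 64)  # e.g. a batch of 64x64 RGB images
y = torch.rand(8, 3, 64, 64)
D = 3 * 64 * 64
d_sum = sample_dist(x, y, "sum")
d_mean = sample_dist(x, y, "mean")
# The two versions differ only by the constant factor sqrt(D), about 110.9 here
print(d_sum[:3])
print(d_mean[:3])
```

Because the factor is the same for every pair, the choice between them behaves like a rescaling of the distance used in the penalty, not a different kind of constraint.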



Your code is very clear, and helpful (at least for me).
Please help me understand this algorithm better. You backprop 1 for fake and -1 for real in this code:

By contrast, this notebook by @tom (which is enlightening as well), https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/Semi-Improved_Training_of_Wasserstein_GAN.ipynb, backprops -1 for fake and 1 for real.

Maybe I got some of the calculations wrong, but if I'm correct, we should decrease the output for real samples (since the EM distance measures how far we are from the real distribution, for real samples we should decrease it) and increase the output for fake samples (where "output" is the output value of the discriminator/critic). This means your code is correct, but I'm confused because @tom's code seems to work as well.

Or is this some kind of symmetric case where we only have to increase the distance between the real and fake discrimination, no matter in which direction?

Thank you :slight_smile:

Maybe you are right. In my opinion, @tom's code treats the output of the discriminator as the error, while my code treats it as the Wasserstein distance. They are optimized in opposite gradient directions. I think both ways may lead the algorithm to converge, but I haven't tested it with my current code yet (maybe I will in a few days).

But for the general purpose of understanding:
From the critic's point of view, we want to decrease the Wasserstein distance between the critic's implicit probability density and the real probability density, while increasing the Wasserstein distance with respect to the **generator's** implicit probability density. Am I right?

Yes, you are right. Explained this way, it is easier to understand. :smiley:

Hi Csaba, Jarrel,

thank you for looking at this in detail!

I must admit that the mathematician in me cringes a bit at @botcs's argument.
As @jarrelscy mentions, this is symmetric (it is a distance after all).

What happens mathematically is that the discriminator (the test function in the supremum) will ideally converge to the negative of what you get when you switch the signs between real and fake. The only important thing is to have opposing signs within the pair (real discr, fake discr) in the discriminator and within the pair (fake discr, fake gen) in the generator. The latter is because we want to maximize the difference between the integrals in the discriminator, but do so by minimizing its negative.
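For reference, this is the Kantorovich-Rubinstein dual form of the Wasserstein-1 distance that the critic approximates, with f ranging over 1-Lipschitz functions:

$$
W(P_r, P_g) = \sup_{\|f\|_L \le 1} \Big( \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)] \Big)
= \sup_{\|f\|_L \le 1} \Big( \mathbb{E}_{x \sim P_g}[f(x)] - \mathbb{E}_{x \sim P_r}[f(x)] \Big)
$$

The second equality holds because $f \mapsto -f$ maps the set of 1-Lipschitz functions onto itself. A maximizer for one sign convention is therefore the negative of a maximizer for the other.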

So approximately (if the penalty term were zero because its weight were infinite), the Wasserstein distance is the negative of the discriminator's loss. The generator's loss lacks the real-data integral term that would make it the true Wasserstein distance. Because that term does not enter the generator's gradient anyway, it is not computed. This is independent of how you pick the signs.
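Here is a toy check of the last point. It uses a hypothetical linear generator `G` and a small critic `f` (not code from the notebook), with the convention that the generator minimizes the estimate E_real[f] - E_fake[f]. Dropping the real-data term leaves the generator's gradient unchanged:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy generator and critic; all sizes are arbitrary placeholders
G = nn.Linear(2, 3)
f = nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 1))

real = torch.randn(16, 3)
z = torch.randn(16, 2)

def generator_grads(loss):
    """Gradients of `loss` w.r.t. the generator parameters only."""
    return torch.autograd.grad(loss, list(G.parameters()))

# Full Wasserstein estimate E_real[f] - E_fake[f], seen from the generator
full = generator_grads(f(real).mean() - f(G(z)).mean())
# What is usually implemented: the real-data term is dropped
short = generator_grads(-f(G(z)).mean())
```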

I took the signs from the WGAN code published by Martin Arjovsky on GitHub.

Note that some time after writing the notebook you linked, I arrived at the conclusion that one-sided loss is better.
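For comparison, here is a sketch of the two-sided gradient penalty from WGAN-GP next to a one-sided variant that only penalizes gradient norms above 1. Both are evaluated at random interpolates between real and fake samples. `gradient_penalty` is a hypothetical helper, and the notebook's own one-sided loss may be formulated differently:

```python
import torch

def gradient_penalty(critic, real, fake, one_sided=True):
    """Penalty on the critic's input-gradient norm at real/fake interpolates.

    two-sided (WGAN-GP): mean((||grad f(x_hat)|| - 1)^2)
    one-sided:           mean(max(0, ||grad f(x_hat)|| - 1)^2)
    """
    b = real.size(0)
    # One interpolation weight per sample, broadcast over the remaining dims
    alpha = torch.rand(b, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)
    out = critic(x_hat)
    # create_graph=True so the penalty can itself be backpropagated (double backward)
    (grad,) = torch.autograd.grad(out.sum(), x_hat, create_graph=True)
    excess = grad.flatten(1).norm(dim=1) - 1
    if one_sided:
        excess = torch.clamp(excess, min=0)  # norms below 1 are not penalized
    return (excess ** 2).mean()
```

The one-sided form leaves critics that are "flatter than 1-Lipschitz" alone, which is one reason a larger penalty weight tends to hurt less with it.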

Best regards



Hi Jarrel,

Thank you for the observation.

Moving the scaling between this and the magnitude of the penalty parameter should be equivalent. What you are effectively doing is changing the metric on the image space (from the Euclidean distance to the Euclidean distance divided by the square root of the number of pixels, since the mean is taken under the square root).

That said, it seems good not to have a penalty term that is overly large compared to the "primary terms". My impression is that larger weights are much less problematic for the one-sided penalty, as the toy problems show.

Best regards


Hi Tom,

I am very impressed by your code, https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/Pytorch_Wasserstein.ipynb. You have shown an example with 1D data, but I would like to know how to calculate the Wasserstein loss for a 2D training dataset. How do you compute the distance between two 2D matrices? Or between two 3D matrices?

It would be interesting if you could show an example with 2D data.

Best regards,

F.S. Yang


Great question!
It should work for more dimensions as well, in the sense that you just need to plug in the right distance function. The number of bins could be a limitation, though: with 100 data points you only get a 10x10 grid.
I should include a sample, really.
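One way to "plug in the right distance function" for 2D data is sketched below as a generic exact optimal-transport linear program solved with SciPy, not the notebook's code. Bins become points in the plane, and the ground cost is the Euclidean distance between bin centers. `wasserstein_lp` and the toy data are hypothetical. The transport plan has n² variables for n bins, so a 10x10 grid already means 10,000 unknowns. That is why the bin count limits this approach.

```python
import numpy as np
from scipy.optimize import linprog
from scipy.spatial.distance import cdist

def wasserstein_lp(a, b, coords, p=1):
    """Exact W_p between histograms a, b (equal total mass) on bin positions `coords`.

    Solves: minimize sum_ij P_ij * d(coords_i, coords_j)^p
            subject to row sums of P = a, column sums of P = b, P >= 0.
    """
    n = len(a)
    cost = cdist(coords, coords) ** p                 # ground distance between bins
    row_sums = np.kron(np.eye(n), np.ones((1, n)))    # sum_j P_ij = a_i
    col_sums = np.kron(np.ones((1, n)), np.eye(n))    # sum_i P_ij = b_j
    res = linprog(cost.ravel(), A_eq=np.vstack([row_sums, col_sums]),
                  b_eq=np.concatenate([a, b]), bounds=(0, None), method="highs")
    return res.fun ** (1.0 / p)

# Usage: two 2D point clouds binned on a 5x5 grid over [0, 1]^2
rng = np.random.default_rng(0)
pts1 = rng.normal(0.35, 0.1, size=(100, 2)).clip(0, 1)
pts2 = rng.normal(0.65, 0.1, size=(100, 2)).clip(0, 1)
edges = np.linspace(0, 1, 6)
centers = (edges[:-1] + edges[1:]) / 2
h1, _, _ = np.histogram2d(pts1[:, 0], pts1[:, 1], bins=[edges, edges])
h2, _, _ = np.histogram2d(pts2[:, 0], pts2[:, 1], bins=[edges, edges])
gx, gy = np.meshgrid(centers, centers, indexing="ij")  # matches histogram2d's layout
coords = np.stack([gx.ravel(), gy.ravel()], axis=1)
print(wasserstein_lp(h1.ravel() / h1.sum(), h2.ravel() / h2.sum(), coords))
```

The same function covers 3D (or any dimension) by changing `coords`. Only the size of the linear program grows.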

Best regards


Hi Tom,

Thank you for your reply. In computer vision we often process 2D images, and I find that computing the Wasserstein loss between two 2D matrices iteratively is very computationally expensive. How can we deal with this problem? Do we need to downsample the prediction and the target? And would the loss between downsampled matrices still be reliable enough for training a network?

I hope you will show us an example with 2D data.

Best regards,

F.S. Yang

I face the same problem.
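One common way to reduce the cost is entropic regularization (Sinkhorn iterations). Nobody in this thread proposed it at this point, but it is the approach behind geomloss's `SamplesLoss(loss="sinkhorn", ...)`, which comes up later. Below is a minimal log-domain sketch for two uniformly weighted point clouds, assuming a squared Euclidean ground cost. `sinkhorn_cost`, `eps` and `n_iters` are hypothetical names, and this is not geomloss's implementation:

```python
import math
import torch

def sinkhorn_cost(x, y, eps=0.05, n_iters=300):
    """Entropy-regularized OT cost between uniform point clouds x (N, D) and y (M, D).

    Runs log-domain Sinkhorn updates of the dual potentials f, g with ground cost
    C_ij = ||x_i - y_j||^2, then returns sum_ij P_ij C_ij for the resulting plan P.
    """
    C = torch.cdist(x, y) ** 2
    n, m = C.shape
    log_a = torch.full((n,), -math.log(n), dtype=x.dtype, device=x.device)
    log_b = torch.full((m,), -math.log(m), dtype=x.dtype, device=x.device)
    f = torch.zeros(n, dtype=x.dtype, device=x.device)
    g = torch.zeros(m, dtype=x.dtype, device=x.device)
    for _ in range(n_iters):
        # f_i = -eps * log sum_j b_j exp((g_j - C_ij) / eps)  -> row marginals match a
        f = -eps * torch.logsumexp(log_b[None, :] + (g[None, :] - C) / eps, dim=1)
        # g_j = -eps * log sum_i a_i exp((f_i - C_ij) / eps)  -> column marginals match b
        g = -eps * torch.logsumexp(log_a[:, None] + (f[:, None] - C) / eps, dim=0)
    # Plan P_ij = a_i b_j exp((f_i + g_j - C_ij) / eps)
    log_P = log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - C) / eps
    return (log_P.exp() * C).sum()

torch.manual_seed(0)
x = torch.rand(200, 2)
y = torch.rand(200, 2) + torch.tensor([0.5, 0.0])
print(sinkhorn_cost(x, y))
```

Each iteration costs O(NM) and is fully differentiable, so it can be used as a training loss. A smaller `eps` gets closer to the unregularized cost but needs more iterations.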

Hi Smth,

I hope you are well. I need to use the Wasserstein distance in my GAN. I used this PyTorch code from this link:


The x and y in my application are 64x9x9 patches: 64 is the batch size and 9x9 are the pixel intensity values. My question is how I should pass x and y: as 64x9x9, 64x81x1, or 64x81? I tried all three and there is no error, but the numbers are different.

```python
import torch
from geomloss import SamplesLoss

# Define a Sinkhorn (~Wasserstein) loss between sampled measures
loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)
L = loss(x, y)  # x, y: the 64x9x9 patch tensors described above
```

Hi @tom,
I am trying to implement WGAN. Could you please tell me if you see any mistakes?

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Assumed to be defined elsewhere, as in the original snippet:
# ngpu, nz, ngf, ndf, device, num_epochs, CriticIt, trainloader

class Generator(nn.Module):
    def __init__(self, ngpu, nz, ngf):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z (nz x 1 x 1), going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 3 x 3
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 5 x 5
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 7 x 7
            nn.ConvTranspose2d(ngf * 2, 1, 3, 1, 0, bias=False),
            nn.Sigmoid(),
            # output size. 1 x 9 x 9
        )

    def forward(self, input):
        return self.main(input)

## -------Define discriminator (critic) ----------------
class Discriminator993(nn.Module):
    def __init__(self, ngpu, ndf):
        super(Discriminator993, self).__init__()
        self.ngpu = ngpu
        # 1 x 9 x 9 -> ndf x 7 x 7
        self.l1 = nn.Sequential(nn.Conv2d(1, ndf, 3, 1, 0, bias=False), nn.LeakyReLU(0.2, inplace=True))
        # ndf x 7 x 7 -> (ndf*2) x 3 x 3
        self.l2 = nn.Sequential(nn.Conv2d(ndf, ndf * 2, 3, 2, 0, bias=False), nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True))
        self.drop_out2 = nn.Dropout(0.5)
        # (ndf*2) x 3 x 3 -> 1 x 1 x 1; no sigmoid: the critic outputs an unbounded score
        self.l3 = nn.Sequential(nn.Conv2d(ndf * 2, 1, 3, 2, 0, bias=False))

    def forward(self, x):
        out = self.l1(x)
        out = self.drop_out2(self.l2(out))
        return self.l3(out)

netG = Generator(ngpu, nz, ngf).to(device)
netD = Discriminator993(ngpu, ndf).to(device)
# The optimizers can only be created once the networks exist
optimizerD = optim.RMSprop(netD.parameters(), lr=0.0002)
optimizerG = optim.RMSprop(netG.parameters(), lr=0.0002)

print("Starting Training Loop...")
for epoch in range(num_epochs):
    # Assumes each batch is a tensor of real patches with shape (b, 1, 9, 9)
    for images1 in trainloader:
        real_cpu = images1.to(device)
        b_size = real_cpu.size(0)

        # (1) Update D network CriticIt times per G update
        # (WGAN normally draws a fresh real batch per critic step; reusing one keeps this short)
        for p in netD.parameters():
            p.requires_grad = True  # re-enable after the G step below
        for Itr in range(CriticIt):
            netD.zero_grad()
            # Critic score on the all-real batch
            errD_real = netD(real_cpu).view(-1).mean()
            # Generate a fake batch; detach so this step does not update G
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake = netG(noise).detach()
            errD_fake = netD(fake).view(-1).mean()
            # Minimizing errD pushes real scores down and fake scores up
            errD = errD_real - errD_fake
            errD.backward()
            optimizerD.step()
            # Weight clipping (the original WGAN Lipschitz constraint)
            for p in netD.parameters():
                p.data.clamp_(-0.01, 0.01)

        # (2) Update G network
        for p in netD.parameters():
            p.requires_grad = False  # to avoid computation
        netG.zero_grad()
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Opposite sign to the fake term in errD: G pushes fake scores down
        errG = netD(netG(noise)).view(-1).mean()
        errG.backward()
        optimizerG.step()
```

I must admit I can never tell from looking at the code if there is an error. :slight_smile:
Also, I would recommend WGAN-GP (or SLOGAN linked above) over WGAN. And then there is SNGAN which I always thought of as being better than WGAN-GP, but it is unclear if the spectral normalization plays well with ReLU, see e.g. Anil et al., Sorting out Lipschitz Function Approximation, from ICML 2019.
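As a rough illustration of the SNGAN-style alternative, here is a critic for 1x9x9 patches built with `torch.nn.utils.spectral_norm`. The layer widths are arbitrary assumptions. For convolutions, this normalizes the weight reshaped to an (out_channels, in_channels*kH*kW) matrix, which only approximates the Lipschitz constant of the convolution itself. The power iteration that estimates the spectral norm is refined only during training-mode forward passes.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

torch.manual_seed(0)
# Small critic for 1 x 9 x 9 inputs with spectral normalization on every weight layer
critic = nn.Sequential(
    spectral_norm(nn.Conv2d(1, 16, 3)),             # 1 x 9 x 9 -> 16 x 7 x 7
    nn.LeakyReLU(0.2),
    spectral_norm(nn.Conv2d(16, 32, 3, stride=2)),  # -> 32 x 3 x 3
    nn.LeakyReLU(0.2),
    spectral_norm(nn.Conv2d(32, 1, 3, stride=2)),   # -> 1 x 1 x 1
)

x = torch.randn(64, 1, 9, 9)
scores = critic(x).view(-1)  # one unbounded score per sample
```

No weight clipping or gradient penalty is needed here. The normalization replaces both.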

Best regards


Hi Tom,

Sorry, I need to use the geomloss package and its SamplesLoss to compute the Sinkhorn (~Wasserstein) loss.
My datasets (X, Y) are batches of 64 patches of size 9x9 (height x width), i.e. 64x9x9. Could you please tell me how I should pass them to the function? Should it be one tensor of 64x9x9, or should I compute the loss for each of the 64 patches as a vector of size 81 and then take the sum?

```python
loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)

L = loss(x, y)
```

The following error from geomloss gives a hint about the expected input shape:
ValueError: Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors.
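The message suggests that the tensor layout decides what geomloss treats as a point cloud. Below is a sketch of the two readings that seem plausible for 64 patches of 9x9. Which one is right depends on what the distance is meant to compare, and the random tensors are placeholders. Passing 64x9x9 directly is also accepted, but as (B, N, D) it means 64 clouds of 9 points in R^9, with each patch row becoming a point. That is probably not what you intend, and it would explain why all three shapes run but give different numbers:

```python
import torch

x = torch.rand(64, 9, 9)  # placeholder: 64 patches of 9x9 pixel intensities
y = torch.rand(64, 9, 9)

# Reading 1: every patch is one sample in R^81, so x and y are each a single
# cloud of N=64 points: shape (N, D) = (64, 81) -> one scalar loss
x_nd, y_nd = x.reshape(64, 81), y.reshape(64, 81)

# Reading 2: every patch is its own cloud of 81 scalar intensities:
# shape (B, N, D) = (64, 81, 1) -> batch mode, one loss value per patch
x_bnd, y_bnd = x.reshape(64, 81, 1), y.reshape(64, 81, 1)

try:
    from geomloss import SamplesLoss
    loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)
    print(loss(x_nd, y_nd))
    print(loss(x_bnd, y_bnd).sum())  # summed over the 64 patches
except ImportError:
    print("geomloss is not installed; shapes:", x_nd.shape, x_bnd.shape)
```

The second reading followed by `.sum()` corresponds to the "compute for each of the 64 patches and then take the sum" option from the question.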