Variational AutoEncoder: Changing loss size_average to False

Hey!

I was trying to get a Variational Autoencoder to work recently, but to no avail. The classic pattern was that the loss would quickly drop to a small value at the beginning and then just stay there.

[Image: VAE loss curve]

Far from optimal, the network would not generate anything useful, only grey images with a slightly stronger intensity in the center.
I could not spot my error until I finally noticed that, in the example, size_average was set to False in the reconstruction term of the loss function.

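import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar):
    # Summed over batch and pixels; reduction='sum' replaces the deprecated size_average=False
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # KL divergence from N(mu, exp(logvar)) to N(0, I), summed over batch and latent dims
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD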
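For reference, both terms in this function are sums rather than averages. The reconstruction term is the pixel-wise binary cross-entropy summed over all 784 pixels (and, with a summed reduction, over the batch as well), and the KLD line is the standard closed-form KL divergence between the encoder's diagonal Gaussian and a standard normal prior, where `logvar` holds $\log\sigma^2$ and the sum runs over the $J$ latent dimensions (and again over the batch):

```latex
\mathrm{BCE}(\hat{x}, x) = -\sum_{p=1}^{784} \Big[ x_p \log \hat{x}_p + (1 - x_p) \log (1 - \hat{x}_p) \Big]

D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}\sigma^2) \,\|\, \mathcal{N}(0, I)\big)
  = -\frac{1}{2} \sum_{j=1}^{J} \Big( 1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2 \Big)
```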

I tried it out, and magically, the results improved.
I do not understand why, especially since this part of the loss deals only with the reconstruction, and a simple autoencoder works just fine without this flag.

Could someone please explain this?

Thanks a lot!

So wait, just to clarify: did you set it to True or to False to get it to work?

I had to set it to False (:

I think the reason is the nature of the task (mind you, this is going to be a non-mathematical explanation on my part). Basically, in a VAE you are asking ‘how closely did each reconstruction match?’, not ‘on average, how well did all my reconstructions match?’. So you want the total amount of error over a batch, not the average error over a batch.
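A small sketch of what the two reductions actually compute, on random stand-in tensors (the batch size of 16 and the random data are assumptions; only the 784 pixels come from the thread). `reduction='mean'` is what the old `size_average=True` did and `reduction='sum'` is `size_average=False`; note that the ratio between them is the total number of elements, not just the batch size:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, pixels = 16, 784                             # assumed MNIST-sized inputs (28 x 28 = 784)
x = torch.rand(batch, pixels)                       # targets in [0, 1]
recon = torch.sigmoid(torch.randn(batch, pixels))   # stand-in decoder output in (0, 1)

bce_mean = F.binary_cross_entropy(recon, x, reduction='mean')  # old size_average=True
bce_sum = F.binary_cross_entropy(recon, x, reduction='sum')    # old size_average=False

# 'mean' divides the total by every element, i.e. by batch * pixels
print((bce_sum / bce_mean).item())   # ≈ 12544 = 16 * 784
per_image = bce_sum / batch          # total error per image, averaged over the batch only
```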

Well, I’m not completely sure of that because a simple auto-encoder works just fine without this flag. However, when you take into account the variational inference, it just doesn’t work anymore without this specific instruction. That’s what puzzled me.

I’ve bumped into the same thing, and it really bothered me. I should be able to counteract this averaging over the batch with a larger learning rate, but it didn’t seem to work. Setting the LR too high for Adam just blew up the training, and setting it to roughly the highest value that didn’t blow it up gave just a blob of white pixels in the center.
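A learning rate only scales the step taken on a fixed objective; it cannot change how that objective weighs reconstruction against KL. The sketch below checks this on the gradient, using random toy tensors (the shapes, the 20-dimensional latent, and the `grad_direction` helper are hypothetical). Multiplying the whole loss by a constant (equivalent to a larger learning rate under plain SGD) leaves the gradient direction unchanged, while switching the BCE from a per-pixel mean to a per-image sum reweights the two terms and changes the direction:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical toy setup: 8 "images" of 784 pixels and a 20-dim latent space
batch, pixels, latent = 8, 784, 20
x = torch.rand(batch, pixels)
logits = torch.randn(batch, pixels, requires_grad=True)
mu = torch.randn(batch, latent, requires_grad=True)
logvar = torch.randn(batch, latent, requires_grad=True)

def grad_direction(per_pixel_mean: bool, scale: float = 1.0) -> torch.Tensor:
    """Flattened gradient of scale * (BCE + KLD) w.r.t. all toy parameters."""
    recon = torch.sigmoid(logits)
    if per_pixel_mean:
        bce = F.binary_cross_entropy(recon, x, reduction='mean')        # divided by batch * 784
    else:
        bce = F.binary_cross_entropy(recon, x, reduction='sum') / batch  # divided by batch only
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch
    grads = torch.autograd.grad(scale * (bce + kld), [logits, mu, logvar])
    return torch.cat([g.flatten() for g in grads])

g_mean = grad_direction(per_pixel_mean=True)
g_mean_scaled = grad_direction(per_pixel_mean=True, scale=100.0)
g_sum = grad_direction(per_pixel_mean=False)

cos = F.cosine_similarity
print(cos(g_mean, g_mean_scaled, dim=0).item())  # ≈ 1: scaling only changes the step length
print(cos(g_mean, g_sum, dim=0).item())          # noticeably below 1: the balance has changed
```

With the per-pixel mean, the reconstruction part of the gradient is tiny next to the KL part, so no choice of step size makes the optimizer care about reconstruction.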

BUT then I found this (https://github.com/pytorch/examples/issues/234), where you can read that binary cross entropy takes the mean over all dimensions (batch and spatial), so the reconstruction loss gets divided by a much larger value (batch size × 784 for MNIST’s 28×28 images) than the KL loss, which is divided only by the batch size. This seems to be the problem: the KL loss averaged over the batch alone is much larger than the reconstruction loss averaged over the batch AND the spatial dimensions. That’s why you just get a blob: the latent distribution overfits to the standard Gaussian and doesn’t carry any information to the decoder. If you disable averaging and then divide by the batch size, everything works well with a higher LR, as expected :blush: Hope I helped!
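A minimal sketch of that fix, assuming MNIST-shaped 784-pixel inputs and using `reduction='sum'` as the current spelling of `size_average=False`: sum both terms, then divide them by the same batch size. The toy tensors and the 20-dimensional latent in the usage part are hypothetical stand-ins:

```python
import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar):
    """VAE loss summed per image, then averaged over the batch only."""
    batch_size = x.size(0)
    # Sum BCE over every pixel of every image (no per-pixel averaging)
    bce = F.binary_cross_entropy(recon_x, x.view(batch_size, -1), reduction='sum')
    # Closed-form KL to N(0, I), summed over latent dims and batch
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Same divisor for both terms, so their relative weight is preserved
    return (bce + kld) / batch_size

# Usage on hypothetical toy tensors
torch.manual_seed(0)
x = torch.rand(8, 1, 28, 28)                 # MNIST-shaped batch
recon = torch.sigmoid(torch.randn(8, 784))   # stand-in decoder output
mu, logvar = torch.randn(8, 20), torch.randn(8, 20)

fixed = loss_function(recon, x, mu, logvar)
bce_mean = F.binary_cross_entropy(recon, x.view(8, -1), reduction='mean')
# The per-pixel mean is 784 times smaller than the per-image BCE used in loss_function
print((bce_mean * 784).item())
```

Dividing by the batch size keeps the loss magnitude independent of the batch size; it is the extra division by the 784 pixels, applied to the BCE alone, that tilts the balance toward the KL term.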