Cross-entropy loss in VAE

Hello there,

I’m currently trying to implement a VAE for dimensionality reduction purposes.

As a base, I started from PyTorch's VAE example, which uses the MNIST dataset.
My own problem, however, does not involve images but a 17-dimensional vector of continuous values.
I want to use the VAE to reduce this to a smaller number of dimensions. Additionally, I use a “history” of these values to pass information about previous values into the network.

So overall, my network gets an input of the following shape:
[BATCH_SIZE x HISTORY_LENGTH x FEATURE_SIZE]

In a small example I used the following:
[64 x 3 x 17]
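A minimal sketch of a batch with this shape. It assumes a fully connected encoder like the one in the MNIST example, which expects one flat vector per sample, so the history and feature dimensions are merged; the variable names are illustrative only:

```python
import torch

BATCH_SIZE, HISTORY_LENGTH, FEATURE_SIZE = 64, 3, 17

# Dummy batch of continuous feature vectors, each with a short history
x = torch.randn(BATCH_SIZE, HISTORY_LENGTH, FEATURE_SIZE)

# Assumption: a fully connected encoder expecting [BATCH_SIZE, in_features],
# so history and features are flattened into one 3 * 17 = 51 vector per sample
x_flat = x.view(BATCH_SIZE, -1)
print(x_flat.shape)  # torch.Size([64, 51])
```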

Long story short, I run into problems when it comes to calculating the loss function (I started from the PyTorch example https://github.com/pytorch/examples/blob/master/vae/main.py#L72).

Since I don’t use images as input, I replaced the binary cross-entropy with regular cross-entropy, using F.cross_entropy(reconstructed_x, x, reduction='sum') instead of F.binary_cross_entropy(reconstructed_x, x, reduction='sum'), and now I get the following error message:

ValueError: Expected target size (64, 17), got torch.Size([64, 3, 17])
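The shapes in this error follow from how F.cross_entropy reads its arguments: for an input of shape [N, C, d1], dimension 1 is treated as the class dimension. A [64, 3, 17] tensor is therefore read as 64 samples, 3 classes and 17 positions, and the target is expected to be a [64, 17] tensor of class indices. A small sketch with random data, only to illustrate the shape convention:

```python
import torch
import torch.nn.functional as F

reconstructed_x = torch.randn(64, 3, 17)  # interpreted as [N=64, C=3, d1=17]

# A target of class indices in [0, 3) with shape [N, d1] = [64, 17] is accepted
target = torch.randint(0, 3, (64, 17))  # dtype torch.int64 by default
loss = F.cross_entropy(reconstructed_x, target, reduction='sum')
print(loss.shape)  # torch.Size([]), i.e. a scalar
```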

Can somebody tell me what I am missing?

Thanks in advance!

nn.CrossEntropyLoss expects its targets as a LongTensor containing class indices, since it’s usually used for multi-class classification.

I’m not sure how your input tensor is defined, but I assume it contains float values.
In this case, you won’t be able to use nn.CrossEntropyLoss to reconstruct the input tensor.
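For comparison, a typical classification call pairs raw scores of shape [N, C] with a LongTensor of N class indices, and the loss is the negative log-probability of the true class. The values below are made up purely for illustration:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])  # [N=2, C=3] unnormalized scores
labels = torch.tensor([0, 2])             # one class index per sample (int64)

# Default reduction='mean': average over the batch
loss = F.cross_entropy(logits, labels)

# Same value by hand: negative log-softmax probability of each true class
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(2), labels].mean()
```

There is no per-element float target in this formulation, which is why it does not map onto reconstructing continuous inputs.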

Was the loss function from the VAE example not working?

Your assumption is correct: my input tensor consists of float values, which explains why it isn’t working.

F.binary_cross_entropy() produces very high losses, so I tried F.mse_loss() instead, which gave better loss values. However, I assumed this is not correct, since MSE loss is not mentioned in the Kingma paper (https://arxiv.org/pdf/1312.6114.pdf, see section 3 for the loss function).
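One way to relate MSE to the paper: section 3 lets the decoder p(x|z) be a multivariate Gaussian for real-valued data (with details in appendix C). If its variance is fixed to 1, the negative log-likelihood is 0.5 * sum((x - x̂)²) plus a constant, i.e. a scaled summed MSE. The sketch below assumes that fixed-variance choice and keeps the KL term in the same form as the PyTorch example's loss_function; mu and logvar stand for the encoder outputs, and the latent size of 8 in the usage lines is hypothetical:

```python
import math
import torch
import torch.nn.functional as F

def gaussian_nll(recon_x, x):
    # Negative log-likelihood of x under N(recon_x, I), summed over all elements
    return 0.5 * ((x - recon_x) ** 2).sum() + 0.5 * x.numel() * math.log(2 * math.pi)

def loss_function(recon_x, x, mu, logvar):
    # Reconstruction term: Gaussian decoder with unit variance (an assumption);
    # this equals gaussian_nll minus its constant, i.e. half the summed squared error
    recon = 0.5 * F.mse_loss(recon_x, x, reduction='sum')
    # KL divergence between q(z|x) = N(mu, exp(logvar)) and the prior N(0, I)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kld

# Usage with dummy tensors (flattened [64 x 3 x 17] input, latent size 8)
x = torch.randn(64, 3 * 17)
recon_x = torch.randn(64, 3 * 17)
mu, logvar = torch.zeros(64, 8), torch.zeros(64, 8)  # KL term is 0 here
loss = loss_function(recon_x, x, mu, logvar)
```

Under this assumption, plain F.mse_loss(..., reduction='sum') differs only by a constant factor, which changes how strongly reconstruction is weighted against the KL term.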

So overall, I’m not sure which loss function would suit my case best. I’m open to any suggestions at this point.

Maybe there are other things causing the high losses, but I still have to find them.

Thanks for your reply! :)