VAE loss function (Cross entropy)

Hello, my question is how I can calculate the reconstruction error of a VAE in this case.

I’m trying to reconstruct inputs of shape (class, height, width) = (20, 100, 100), which I got from a segmentation task (before the argmax, so let’s call them logits).

I want to calculate the loss between the input (logits) and the output logits with cross entropy, but I don’t want to take the argmax of either tensor, since taking the argmax simply discards the other classes’ information.

Normally I would have to take an approach like the one below, right?

input.shape = (20, 100, 100)
output.shape = (20, 100, 100)

# convert the target logits to hard class indices (this is where the per-class information is lost)
label = torch.argmax(input, dim=0)                         # (100, 100)
criterion = nn.CrossEntropyLoss()
loss = criterion(output.unsqueeze(0), label.unsqueeze(0))  # add a batch dimension

But I believe this way only compares the argmax of the data, so I came up with the code below.

def entropy(x, recon_x, h, w, batch, class_num, weight, ignore_id=19):
    
    softmax = nn.Softmax(dim=0)   # softmax over the class dimension of (class, h, w)
    losses = []
    
    for i in range(batch):
        
        loss = (softmax(x[i]) * softmax(recon_x[i]))
        
        for d in range(class_num):
            if d == ignore_id:
                continue
                
            # collect the weighted per-class loss maps
            losses.append(weight[d] * loss[d])
            
    return torch.sum(torch.stack(losses))

But I’m not sure if this makes sense.
It would be very helpful if someone could give me an idea.

Thank you

Cross entropy loss considers all your classes during training/evaluation.

Argmax is used only to get the class prediction (the class with the highest probability); it is needed only during inference, not during training/evaluation.
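
For example, a minimal sketch (with made-up shapes) of how nn.CrossEntropyLoss is typically used for segmentation-style outputs: the full logit tensor goes in as the input, only the target is reduced to class indices, and argmax on the logits is only used to read off the predicted class.

import torch
import torch.nn as nn

logits = torch.randn(4, 20, 100, 100)          # (N, C, H, W) raw scores, all classes kept
target = torch.randint(0, 20, (4, 100, 100))   # (N, H, W) class indices

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)   # uses all 20 logits per pixel

pred = logits.argmax(dim=1)        # argmax only to obtain the predicted class at inference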

Thank you for your reply, @xian_kgx.

And I think I understand what you’re saying.

But the reason I took the argmax of the input data is that I had to reshape it to fit the expected target shape for nn.CrossEntropyLoss.
https://pytorch.org/docs/stable/nn.html#crossentropyloss

So again, the input data’s shape going into the VAE is
(Channel=Class, Height, Width) = (20, 100, 100)

The output data’s shape coming out of the VAE is
(Channel=Class, Height, Width) = (20, 100, 100)
and I want to minimize the loss between those two.

To match the expected target shape for nn.CrossEntropyLoss, I took the argmax.
But the problem is that I don’t want to take the argmax of the input data, since I want to keep every class’s information (not only the class with the highest probability).

Thank you

Can you try BCE loss instead? It calculates the probability of a class for each of the 20 classes. Or you could use MSE loss.
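
For illustration, a minimal sketch of what that could look like for the (class, H, W) maps discussed in this thread; the tensors and shapes here are placeholders, BCEWithLogitsLoss applies the sigmoid internally, and the MSE variant is computed on the squashed reconstruction.

import torch
import torch.nn as nn

x       = torch.softmax(torch.randn(4, 20, 100, 100), dim=1)   # soft target maps in [0, 1]
recon_x = torch.randn(4, 20, 100, 100)                         # raw decoder logits

bce = nn.BCEWithLogitsLoss(reduction='sum')(recon_x, x)        # per-element BCE over every class map
mse = nn.MSELoss(reduction='sum')(torch.sigmoid(recon_x), x)   # or per-element MSE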

Your loss from this line:

loss = (softmax(x[i]) * softmax(recon_x[i]))

doesn’t seem correct. For each coordinate you have a value in the range [0, 1] from your softmax function. If your prediction (for a particular coordinate) is 0 and your target is 0, you have 0 loss. And if your prediction is 1 and your target is 1, you have a loss of 1.

I suggest a simple MSE or BCE loss will do.

I suggest a simple MSE or BCE loss will do.

I tried, but it keeps giving me a huge negative number like below.
Do you have any idea what is actually going on?
The input to the VAE is in the range [0, 1] and the output is in the same range too.

Also, MSE didn’t work.
The reconstructed data does not reproduce the input data at all.
By changing the weight on the KLD term I could get better results,
but I don’t want to take that approach, since it wouldn’t be a general solution anymore if I changed the input data.

 41%|████      | 13/32 [01:04<01:27,  4.61s/it]
kld :  53.88439178466797 recon :  -53301.86328125
 44%|████▍     | 14/32 [01:08<01:22,  4.58s/it]
kld :  52.833587646484375 recon :  -56408.796875
 47%|████▋     | 15/32 [01:13<01:18,  4.59s/it]
kld :  51.79539489746094 recon :  -59532.7265625
 50%|█████     | 16/32 [01:17<01:12,  4.56s/it]
kld :  50.653541564941406 recon :  -62645.80859375
 53%|█████▎    | 17/32 [01:22<01:08,  4.54s/it]
kld :  49.04255676269531 recon :  -65793.15625
 56%|█████▋    | 18/32 [01:26<01:03,  4.52s/it]
kld :  47.729461669921875 recon :  -68937.03125
 59%|█████▉    | 19/32 [01:31<00:58,  4.51s/it]
kld :  47.20650100708008 recon :  -72088.6171875
 62%|██████▎   | 20/32 [01:35<00:54,  4.51s/it]
kld :  46.273780822753906 recon :  -75279.53125
 66%|██████▌   | 21/32 [01:40<00:49,  4.51s/it]
kld :  44.844871520996094 recon :  -78400.25
 69%|██████▉   | 22/32 [01:44<00:44,  4.49s/it]
kld :  42.69890213012695 recon :  -81610.5
・
・
・
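
For reference, a minimal sketch of how the two terms printed above (recon and kld) are usually computed in a VAE, assuming recon_x has been squashed into [0, 1] (e.g. with a sigmoid); with both arguments in [0, 1], F.binary_cross_entropy is non-negative by definition, so a negative reconstruction value is worth double-checking against this form.

import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, logvar):
    # reconstruction term: both recon_x and x are expected to lie in [0, 1]
    recon = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between N(mu, sigma^2) and the standard normal prior
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon, kld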

What do you think about the function below?
Do you have any idea?


inputs.shape = (512, 19, 64, 64)
outputs.shape = (512, 19, 64, 64)

def entropy(inputs, outputs, weight, batch=batch_size, h=64, w=64, class_num=19, ignore_id=19):
    
    softmax = nn.Softmax(dim=0)   # softmax over the class dimension of (class, h, w)
    losses = []
    
    for b in range(batch):
        
        loss = - softmax(inputs[b]) * torch.log(softmax(outputs[b]))
        
        for c in range(class_num):
            if c == ignore_id:
                continue
                
            # collect the weighted per-class loss maps
            losses.append((weight[c] * loss[c]) / batch_size)
            
    return torch.sum(torch.stack(losses))
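
For comparison, here is a vectorized sketch of the same idea (a soft-target cross entropy, -p * log q) that avoids the Python loops; it assumes (N, C, H, W) logits, a (C,) weight tensor, and that ignore_id indexes one of the C channels. Whether this is the right reconstruction term for the VAE is exactly the question being discussed here.

import torch
import torch.nn.functional as F

def soft_cross_entropy(inputs, outputs, weight, ignore_id=None):
    # inputs, outputs: (N, C, H, W) logits; weight: (C,) per-class weights
    w = weight.clone()
    if ignore_id is not None and ignore_id < w.numel():
        w[ignore_id] = 0.0                      # drop the ignored class, if any
    p = F.softmax(inputs, dim=1)                # soft "target" distribution per pixel
    log_q = F.log_softmax(outputs, dim=1)       # log-probabilities of the reconstruction
    ce = -(p * log_q) * w.view(1, -1, 1, 1)     # weighted per-class contributions
    return ce.sum() / inputs.size(0)            # sum over classes and pixels, mean over batch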

Why do you use cross entropy in a VAE? Aren’t you trying to reconstruct the input? In that case, either MSE or BCE should be used.

@Shisho_Sama

Do you know the theoretical reason why BCE or MSE is suited to a VAE / AE loss function?
I found articles and papers saying things like

’ it is usually used ’

or

’ There are two common loss functions used for training autoencoders, these include the mean-squared error (MSE) and the binary cross-entropy (BCE) ’

But I really don’t get why channel-wise cross entropy is not suited in my case.

I think it’s because in BCE and MSE everything (all pixel values) is treated equally, whereas in your case, it means there is one true class only. In reconstructing the input, we therefore check each pixel value against its respective real value, regardless of the other pixels.
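
A toy illustration of that difference, with made-up tensors: BCE and MSE compare every element independently with its counterpart, while cross entropy normalizes across the class channel and needs a single true class per pixel (here obtained with argmax).

import torch
import torch.nn.functional as F

x     = torch.rand(1, 20, 100, 100)   # soft "target" maps in [0, 1]
recon = torch.rand(1, 20, 100, 100)   # reconstruction, also in [0, 1]

bce = F.binary_cross_entropy(recon, x)          # element-wise comparison
mse = F.mse_loss(recon, x)                      # element-wise comparison
ce  = F.cross_entropy(recon, x.argmax(dim=1))   # one hard class per pixel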

@Shisho_Sama

Thank you for your opinion.

whereas in your case, it means there is one true class only.

Regarding this part: what if the target wasn’t one-hot encoded, and the input and target were instead continuous numbers in [0, 1]? Would the model still recognize some true class?

In a VAE you don’t use labels; you compare the actual input with the reconstructed version from the decoder, unless you want to use the conditional version, which incorporates the one-hot encoded label in the input and in the latent variable z.
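
For what it’s worth, a minimal sketch of that conditional idea; encoder, decoder, x and y_onehot are hypothetical names, and the one-hot label is simply concatenated to the encoder input and to the latent code.

import torch

# hypothetical modules/tensors: encoder returns (mu, logvar), decoder maps the
# conditioned latent back to the input space, y_onehot is the one-hot label
enc_in = torch.cat([x.view(x.size(0), -1), y_onehot], dim=1)   # condition the encoder on y
mu, logvar = encoder(enc_in)
z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()           # reparameterization trick
dec_in = torch.cat([z, y_onehot], dim=1)                       # condition the decoder on y
recon_x = decoder(dec_in)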