Hello, my question is how to calculate the reconstruction error of a VAE in this case.
I'm trying to reconstruct inputs of shape (class, height, width) = (20, 100, 100), which I got from a segmentation task (the scores before argmax, i.e., the logits).
I want to calculate the cross-entropy loss between the input logits and the output logits, but I don't want to take the argmax of either, since taking the argmax simply discards the information about the other classes.
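Written out, the loss described here is a soft-target cross-entropy between the per-pixel class distributions. As a sketch, let $x$ be the input logits, $\hat{x}$ the reconstructed logits, $C = 20$ the number of classes, $H = W = 100$, and $w_c$ an optional per-class weight (an assumption, not something the setup requires):

$$
\mathcal{L}(x, \hat{x}) = -\frac{1}{HW} \sum_{i=1}^{H} \sum_{j=1}^{W} \sum_{c=1}^{C} w_c \, p_{c,ij} \log q_{c,ij},
\qquad
p_{\cdot,ij} = \operatorname{softmax}(x_{\cdot,ij}), \quad
q_{\cdot,ij} = \operatorname{softmax}(\hat{x}_{\cdot,ij})
$$

If $p$ is one-hot (i.e., the argmax of the input), each per-pixel term reduces to $-w_c \log q_{c,ij}$ for the labeled class, which is the ordinary hard-label cross-entropy.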
Ordinarily, I would have to use an approach like the one below, right?
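```python
import torch
import torch.nn as nn

x = torch.randn(4, 20, 100, 100)        # input logits (batch, class, h, w); random placeholder
recon_x = torch.randn(4, 20, 100, 100)  # VAE output logits, same shape; random placeholder
label = torch.argmax(x, dim=1)          # hard labels, shape (batch, h, w)
criterion = nn.CrossEntropyLoss()
loss = criterion(recon_x, label)
```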
But I believe this approach only uses the argmax of the input as the target, so the probabilities of the other classes are lost.
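Depending on the PyTorch version, the argmax may not be needed at all: since PyTorch 1.10, `nn.CrossEntropyLoss` also accepts a target of class probabilities with the same shape as the input, so the softmax of the input logits can be passed directly. A minimal sketch with random placeholder tensors (`soft_target` is a hypothetical name):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Random placeholders with shape (batch, class, h, w) = (4, 20, 100, 100)
x = torch.randn(4, 20, 100, 100)        # input logits from the segmentation model
recon_x = torch.randn(4, 20, 100, 100)  # logits reconstructed by the VAE

# Per-pixel class distribution of the input, used as a soft target (requires PyTorch >= 1.10)
soft_target = F.softmax(x, dim=1)

criterion = nn.CrossEntropyLoss()
loss = criterion(recon_x, soft_target)  # scalar, averaged over batch and pixels
```

Note that `ignore_index` only applies when the target contains class indices; with probability targets, a class can instead be excluded by giving it a zero entry in `weight`.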
So I came up with this code:
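```python
import torch
import torch.nn as nn

def entropy(x, recon_x, h, w, batch, class_num, weight, ignore_id=19):
    # x, recon_x: logits of shape (batch, class_num, h, w); weight: (class_num,)
    softmax = nn.Softmax(dim=0)  # softmax over the class dimension of one sample
    losses = []
    for i in range(batch):
        # element-wise product of the two class distributions, shape (class_num, h, w)
        loss = softmax(x[i]) * softmax(recon_x[i])
        for d in range(class_num):
            if d == ignore_id:
                continue
            losses.append(weight[d] * loss[d])  # weighted map for class d, shape (h, w)
    return torch.sum(torch.stack(losses))
```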
But I'm not sure whether this makes sense.
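For comparison, here is a sketch of a weighted soft-target cross-entropy that keeps the per-class weights and the skipped class from the `entropy` function. One point worth noting: the product $p_c q_c$ grows when the two distributions agree, so minimizing it would push them apart, whereas cross-entropy uses $-p_c \log q_c$. The function name `soft_cross_entropy` is hypothetical, and "ignoring" class 19 is assumed to mean dropping its term by zeroing its weight:

```python
import torch
import torch.nn.functional as F

def soft_cross_entropy(x, recon_x, weight, ignore_id=19):
    """Weighted cross-entropy between per-pixel class distributions.

    x, recon_x: logits of shape (batch, class_num, h, w)
    weight:     per-class weights of shape (class_num,)
    ignore_id:  class whose term is dropped (its weight is set to zero)
    """
    p = F.softmax(x, dim=1)                # target distribution from the input logits
    log_q = F.log_softmax(recon_x, dim=1)  # log of the reconstructed distribution
    w = weight.to(dtype=x.dtype, device=x.device).clone()
    if ignore_id is not None:
        w[ignore_id] = 0.0                 # exclude this class from the loss
    # broadcast weights over (batch, class, h, w), sum over classes -> (batch, h, w)
    per_pixel = -(w.view(1, -1, 1, 1) * p * log_q).sum(dim=1)
    return per_pixel.mean()                # average over batch and pixels

# Usage with random placeholder tensors
x = torch.randn(2, 20, 100, 100)
recon_x = torch.randn(2, 20, 100, 100)
weight = torch.ones(20)
loss = soft_cross_entropy(x, recon_x, weight)
```

Using `F.log_softmax` rather than `torch.log(F.softmax(...))` avoids taking the log of values that underflow to zero. The loop over the batch and classes is replaced by broadcasting.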
It would be very helpful if someone could give me an idea.
Thank you