Hello, my question is how to calculate the reconstruction error of a VAE in this case.
I'm trying to reconstruct inputs of shape (class, height, width) = (20, 100, 100), which I got from a segmentation task (the output before argmax, i.e. the logits).
I want to calculate a cross-entropy loss between the input logits and the output logits, but I don't want to take the argmax of either, since taking the argmax simply discards the other classes' information.
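To illustrate the point, here is a tiny sketch with made-up probabilities for a single pixel (3 classes instead of 20, purely for brevity). Two quite different class distributions collapse to the same hard label once argmax is applied.

```python
import torch

# Two hypothetical pixels with very different class probabilities (3 classes for brevity)
confident = torch.tensor([0.90, 0.05, 0.05])
uncertain = torch.tensor([0.40, 0.35, 0.25])

# Both collapse to the same hard label, so the rest of the distribution is lost
label_a = torch.argmax(confident).item()
label_b = torch.argmax(uncertain).item()
print(label_a, label_b)  # 0 0
```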
Ordinarily, I would do something like the following, right?
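```python
import torch
import torch.nn as nn

# x:       input logits,         shape (batch, 20, 100, 100)
# recon_x: reconstructed logits, shape (batch, 20, 100, 100)
label = torch.argmax(x, dim=1)      # hard labels, shape (batch, 100, 100)
criterion = nn.CrossEntropyLoss()
loss = criterion(recon_x, label)
```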
But I believe this approach just compares the argmax of both.
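One detail that may help here: with hard labels, `nn.CrossEntropyLoss` applies log-softmax to all logits of the output, and only the target (the input) is reduced to its argmax. The sketch below, using small made-up shapes (batch 2, 20 classes, 4x4 pixels) instead of the real data, checks that the hard-label loss equals a soft cross-entropy whose target is the one-hot encoding of the input's argmax.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical small shapes: batch=2, classes=20, 4x4 pixels
x = torch.randn(2, 20, 4, 4)        # input logits
recon_x = torch.randn(2, 20, 4, 4)  # reconstructed logits

# Usual hard-label cross-entropy: the target is the argmax of the input
label = torch.argmax(x, dim=1)                       # (2, 4, 4)
hard = nn.CrossEntropyLoss()(recon_x, label)

# Same value written out: the input enters only as a one-hot target,
# while the reconstruction contributes its full log-softmax
one_hot = F.one_hot(label, num_classes=20).permute(0, 3, 1, 2).float()  # (2, 20, 4, 4)
manual = -(one_hot * F.log_softmax(recon_x, dim=1)).sum(dim=1).mean()

print(torch.allclose(hard, manual))  # True
```

So the information that gets lost is the input's distribution over the non-argmax classes, not the output's.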
So I came up with the following code:
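```python
import torch
import torch.nn as nn

def entropy(x, recon_x, h, w, batch, class_num, weight, ignore_id=19):
    # x, recon_x: (batch, class_num, h, w) logits
    softmax = nn.Softmax(dim=0)                  # softmax over the class dim of one sample
    losses = torch.zeros(h, w, device=x.device)  # running per-pixel sum
    for i in range(batch):
        # element-wise product of the two per-pixel class distributions
        loss = softmax(x[i]) * softmax(recon_x[i])
        for d in range(class_num):
            if d == ignore_id:
                continue                         # skip the ignored class
            losses = losses + weight[d] * loss[d]  # accumulate weighted class term
    return torch.sum(losses)
```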
But I’m not sure if this makes sense.
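For comparison, a common way to compare two full class distributions is a soft-target cross-entropy, H(p, q) = -Σ_c p_c log q_c, where p is the softmax of the input logits and q is the softmax of the reconstruction. Note the log on q; the function above multiplies the two probabilities directly. Below is a minimal sketch, assuming `x` and `recon_x` are batched (batch, class_num, h, w) logits and `weight` is a tensor of per-class weights; `soft_cross_entropy` is a hypothetical name, and the ignored class is dropped by zeroing its weight. It cross-checks against `nn.CrossEntropyLoss`, which accepts class probabilities as the target in PyTorch 1.10 and later (in that mode `ignore_index` does not apply, hence the zero weight).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def soft_cross_entropy(x, recon_x, weight, ignore_id=19):
    """Hypothetical helper: cross-entropy between softmax(x) (target) and
    softmax(recon_x), summed over classes, pixels and batch.

    x, recon_x: (batch, class_num, h, w) logits
    weight:     (class_num,) per-class weights
    ignore_id:  class whose term is dropped (its weight is set to 0)
    """
    p = F.softmax(x, dim=1)                # target distribution, no argmax
    log_q = F.log_softmax(recon_x, dim=1)  # log of reconstructed distribution
    w = weight.clone().to(x.dtype)
    w[ignore_id] = 0.0                     # skip the ignored class
    # -sum_c w_c * p_c * log q_c at every pixel, then summed over everything
    return -(w.view(1, -1, 1, 1) * p * log_q).sum()

torch.manual_seed(0)
x = torch.randn(2, 20, 8, 8)        # made-up input logits
recon_x = torch.randn(2, 20, 8, 8)  # made-up reconstructed logits
weight = torch.rand(20)             # made-up class weights

loss = soft_cross_entropy(x, recon_x, weight)

# Built-in equivalent with a probability target (PyTorch >= 1.10)
w = weight.clone()
w[19] = 0.0
builtin = nn.CrossEntropyLoss(weight=w, reduction="sum")(recon_x, F.softmax(x, dim=1))
print(torch.allclose(loss, builtin, rtol=1e-4))  # True
```

If KL divergence is preferred, `F.kl_div(log_q, p, reduction="sum")` differs from the unweighted version of this loss only by the entropy of p, which does not depend on `recon_x`, so it gives the same gradients for the decoder (but has no per-class weights).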
It would be very helpful if someone could give me an idea.