Hello, my question is how to calculate the reconstruction error of a VAE in this case.

I'm trying to reconstruct inputs of shape (class, height, width) = (20, 100, 100), which I got from a segmentation task (before the argmax, i.e. the logits).

I want to calculate a cross-entropy loss between the input logits and the output logits, but I don't want to take the argmax of either.

Taking the argmax simply discards the information about the other classes.
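Written out, a cross-entropy that keeps every class's information uses the input's softmax as a soft target instead of its argmax. As a sketch, with C = 20 classes, H = W = 100, input logits $x$, reconstructed logits $\hat{x}$, and optional per-class weights $w_c$ (an assumption; set $w_c = 1$ for the unweighted case):

```latex
\mathcal{L}(x, \hat{x}) = -\frac{1}{HW} \sum_{i=1}^{H} \sum_{j=1}^{W} \sum_{c=1}^{C} w_c \, p_{c,i,j} \log q_{c,i,j},
\qquad
p_{c,i,j} = \frac{e^{x_{c,i,j}}}{\sum_{k=1}^{C} e^{x_{k,i,j}}},
\qquad
q_{c,i,j} = \frac{e^{\hat{x}_{c,i,j}}}{\sum_{k=1}^{C} e^{\hat{x}_{k,i,j}}}
```

When $p$ is one-hot (the argmax case), this reduces to the usual cross-entropy with hard labels.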

Ordinarily, I would take an approach like the one below, right?

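```
import torch
import torch.nn as nn

# Shapes with a batch dimension: (batch, class, height, width) = (1, 20, 100, 100)
x = torch.randn(1, 20, 100, 100)        # input logits
recon_x = torch.randn(1, 20, 100, 100)  # reconstructed (output) logits
label = torch.argmax(x, dim=1)          # hard labels over the class axis, shape (1, 100, 100)
criterion = nn.CrossEntropyLoss()
loss = criterion(recon_x, label)
```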

But I believe this way only compares against the argmax of the input, so the other classes' information in the target is lost.
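For reference, since PyTorch 1.10 `nn.CrossEntropyLoss` also accepts class probabilities as the target (a float tensor with the same shape as the input), so the softmax of the input logits can be passed directly instead of its argmax. A minimal sketch, assuming a batch dimension in front and random tensors standing in for the real logits:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed shapes: (batch, class, height, width); random tensors stand in for real logits
x = torch.randn(2, 20, 100, 100)        # input logits from the segmentation model
recon_x = torch.randn(2, 20, 100, 100)  # logits reconstructed by the VAE

criterion = nn.CrossEntropyLoss()

# Hard targets: only the argmax class of each pixel is kept
hard_loss = criterion(recon_x, torch.argmax(x, dim=1))

# Soft targets: the full per-pixel class distribution of the input is kept
soft_target = torch.softmax(x, dim=1)
soft_loss = criterion(recon_x, soft_target)

# Same quantity by hand: -sum_c p_c * log q_c, averaged over batch and pixels
manual = -(soft_target * torch.log_softmax(recon_x, dim=1)).sum(dim=1).mean()
print(hard_loss.item(), soft_loss.item(), manual.item())
```

Note that `ignore_index` only applies to class-index targets, not to probability targets; per-class `weight` works with both.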

So I came up with this code:

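```
import torch
import torch.nn as nn

def entropy(x, recon_x, h, w, batch, class_num, weight, ignore_id=19):
    # x, recon_x: logits of shape (batch, class_num, h, w)
    softmax = nn.Softmax(dim=0)  # softmax over the class axis of a single sample
    losses = []
    for i in range(batch):
        # element-wise product of the two per-pixel class distributions, shape (class_num, h, w)
        loss = softmax(x[i]) * softmax(recon_x[i])
        for d in range(class_num):
            if d == ignore_id:
                continue
            losses.append(weight[d] * loss[d])  # weighted (h, w) map for class d
    # torch.cat returns a new tensor, so collect the maps in a list and stack them once
    return torch.sum(torch.stack(losses))
```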

But I’m not sure if this makes sense.
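If the goal is a weighted cross-entropy against soft targets that skips one class, the per-sample and per-class loops can be replaced with tensor operations. A sketch under stated assumptions: inputs have shape (batch, class, height, width), `weight` is a 1-D tensor of length class_num, the ignored class is dropped by zeroing its weight (it still takes part in the softmax normalization), and the helper name `soft_cross_entropy` is hypothetical. Unlike the code above, it computes $-p \log q$ rather than the product of the two softmaxes.

```python
import torch
import torch.nn.functional as F


def soft_cross_entropy(x, recon_x, weight, ignore_id=None):
    """Weighted cross-entropy between two sets of logits (hypothetical helper).

    x, recon_x: logits of shape (batch, class_num, height, width).
    weight: 1-D tensor of per-class weights, length class_num.
    ignore_id: optional class index excluded from the loss.
    """
    p = F.softmax(x, dim=1)                # soft target: input class distribution per pixel
    log_q = F.log_softmax(recon_x, dim=1)  # log-probabilities of the reconstruction
    w = weight.clone().to(dtype=log_q.dtype, device=log_q.device)
    if ignore_id is not None:
        w[ignore_id] = 0.0                 # drop the ignored class from the sum
    # Broadcast weights over (batch, class, h, w), sum over classes, mean over pixels
    per_pixel = -(w.view(1, -1, 1, 1) * p * log_q).sum(dim=1)
    return per_pixel.mean()


torch.manual_seed(0)
x = torch.randn(2, 20, 100, 100)
recon_x = torch.randn(2, 20, 100, 100, requires_grad=True)
weight = torch.ones(20)
loss = soft_cross_entropy(x, recon_x, weight, ignore_id=19)
loss.backward()  # gradients flow into recon_x, as needed when training the VAE
print(loss.item())
```

Whether to normalize the ignored class out of the softmax as well (by slicing it off before `F.softmax`) is a design choice that depends on what class 19 represents.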

It would be very helpful if someone could give me an idea.

Thank you