Cross entropy loss clarification

Here is a code snippet comparing PyTorch's built-in `nn.CrossEntropyLoss` with a manual calculation.
Note that I've used for loops to show how this loss is calculated. The only difference between standard multi-class classification and multi-class segmentation is that, for segmentation, the same loss is computed for every pixel and then averaged. You should not use this code in practice. It is much slower than the internal implementation, and it is less numerically stable, because it computes log(sum(exp(x))) directly instead of using log_softmax:

# setup
import torch
from torch import nn

criterion = nn.CrossEntropyLoss()

batch_size = 2
nb_classes = 4

# Standard multi-class classification: one row of logits per sample,
# one class index per sample.
output = torch.randn(batch_size, nb_classes)          # (batch_size, nb_classes)
target = torch.randint(0, nb_classes, (batch_size,))  # (batch_size,)
loss = criterion(output, target)

# manual calculation
# Per-sample cross entropy: -logit[t] + log(sum_j exp(logit[j])).
# NOTE: log(sum(exp(.))) can overflow for large logits; torch.logsumexp
# (or log_softmax) is the numerically stable equivalent. The naive form is
# kept here on purpose to make the formula explicit.
loss_manual = 0.
for logit, t in zip(output, target):
    # logit: logits of the current sample, t: its target class index
    loss_manual += -1. * logit[t] + torch.log(torch.sum(torch.exp(logit)))

# calculate mean loss: nn.CrossEntropyLoss() averages over the batch by
# default (reduction='mean'), so divide the summed loss by the sample count.
loss_elements = output.size(0)
loss_manual = loss_manual / loss_elements
print(torch.allclose(loss, loss_manual))
# > True

# for segmentation: the same loss, applied independently to every pixel
h, w = 4, 4
output = torch.randn(batch_size, nb_classes, h, w)          # (batch_size, nb_classes, h, w)
target = torch.randint(0, nb_classes, (batch_size, h, w,))  # (batch_size, h, w)
loss = criterion(output, target)

# manual calculation
# NOTE: same naive log(sum(exp(.))) as above; torch.logsumexp is the
# numerically stable equivalent.
loss_manual = 0.
for logits_img, target_img in zip(output, target):
    for h_ in range(h):
        for w_ in range(w):
            # class logits of the current pixel: (nb_classes,)
            logit = logits_img[:, h_, w_]

            # target class index of the current pixel
            t = target_img[h_, w_]

            loss_manual += -1. * logit[t] + torch.log(torch.sum(torch.exp(logit)))

# calculate mean loss over every pixel of every sample
# (target.numel() == batch_size * h * w, the number of per-pixel losses)
loss_elements = target.numel()
loss_manual = loss_manual / loss_elements
print(torch.allclose(loss, loss_manual))
# > True
1 Like