Pixel Weight Map for CrossEntropyLoss


I am trying to implement a weighted CrossEntropyLoss with a different weight for each pixel of the input image.

My code currently looks like this:

import torch

batch_size = 8
out_channels = 3
W = 128
H = 128

# Initialize logits etc. with random
logits = torch.FloatTensor(batch_size, out_channels, H, W).normal_()
target = torch.FloatTensor(batch_size, H, W).random_(0, out_channels)

target = target.long()

# For example only: the real weights will have other values but keep this shape
weights = torch.ones(batch_size, H, W)

loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
raw_loss = loss_fn(logits, target)
weighted_loss = weights * raw_loss

loss = torch.sum(weighted_loss.flatten(start_dim=1), dim=0)
loss = torch.mean(loss)

I wanted to be sure that there aren’t any mistakes here.

Thank you.

The code looks generally alright. Some minor points:

  • Don’t use torch.FloatTensor; use the factory functions instead (e.g. torch.empty(...).normal_(), torch.randn(...) etc.)
  • I’m unsure about the logic of summing the loss over dim0 and then taking the mean, but I assume it fits your use case. (For comparison, a class-weighted nn.CrossEntropyLoss with reduction='mean' normalizes by the sum of the weights that were applied.)
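
To make both points concrete, here is a minimal sketch rather than a definitive implementation. It builds the tensors with factory functions and compares the reduction from the question with a weight-normalized mean. The random weight map and the names loss_original and loss_normalized are assumptions made for illustration, and whether normalizing by the weight sum is appropriate depends on your use case.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, out_channels, H, W = 8, 3, 128, 128

# Factory functions instead of torch.FloatTensor(...)
logits = torch.randn(batch_size, out_channels, H, W)
target = torch.randint(0, out_channels, (batch_size, H, W))  # int64 by default

# Hypothetical per-pixel weight map, random here for illustration only
weights = torch.rand(batch_size, H, W)

# Unreduced per-pixel loss, shape [batch_size, H, W]
raw_loss = F.cross_entropy(logits, target, reduction="none")
weighted_loss = weights * raw_loss

# Reduction from the question: sum over the batch dim, then mean over pixels.
# This equals weighted_loss.sum() / (H * W).
loss_original = weighted_loss.flatten(start_dim=1).sum(dim=0).mean()

# Alternative: normalize by the total weight, analogous to what the
# class-weighted nn.CrossEntropyLoss does with reduction='mean'
loss_normalized = weighted_loss.sum() / weights.sum()
```

With the all-ones weights from the question, loss_normalized reduces to the plain mean over all pixels, while loss_original is batch_size times that mean, so its scale changes with the batch size.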