Hello,
I am trying to implement a weighted CrossEntropyLoss with a different weight for each pixel of the input image.
My code currently looks like this:
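For reference, here is one common way to write a per-pixel weighted cross-entropy. This is a sketch: the symbols are introduced here (logits x, integer targets y, per-pixel weights ω, pixel indices i, j), and dividing by the batch size N is only one normalization choice. Dividing by the sum of the weights is another common choice.

```latex
\ell_{n,i,j} = -\log \frac{\exp\left(x_{n,\,y_{n,i,j},\,i,j}\right)}{\sum_{c=1}^{C} \exp\left(x_{n,c,i,j}\right)}
\qquad
\mathcal{L} = \frac{1}{N} \sum_{n=1}^{N} \sum_{i=1}^{H} \sum_{j=1}^{W} \omega_{n,i,j}\, \ell_{n,i,j}
```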
import torch

batch_size = 8
out_channels = 3
W = 128
H = 128

# Random logits of shape (N, C, H, W) and integer class targets of shape (N, H, W)
logits = torch.randn(batch_size, out_channels, H, W)
target = torch.randint(0, out_channels, (batch_size, H, W))  # already int64 (long)

# For example only: the real weights will have other values but keep this shape (N, H, W)
weights = torch.ones(batch_size, H, W)

# reduction="none" keeps one loss value per pixel, shape (N, H, W)
loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
raw_loss = loss_fn(logits, target)
weighted_loss = weights * raw_loss

# Sum over the pixels of each image (dim=1 after flattening), then average over the batch.
# Summing over dim 0 instead would sum over the batch and then average over pixels.
loss = torch.sum(weighted_loss.flatten(start_dim=1), dim=1)
loss = torch.mean(loss)
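One way to sanity-check this kind of loss is to compare it against the built-in reductions in cases where the answer is known. The sketch below assumes the intended reduction is "sum over pixels per image, then mean over the batch". With all-ones weights that should equal `F.cross_entropy(..., reduction="sum") / N`. With a 0/1 weight map it should match marking the zero-weight pixels with the default `ignore_index` (-100). The mask here is random and purely illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

N, C, H, W = 8, 3, 128, 128
logits = torch.randn(N, C, H, W)
target = torch.randint(0, C, (N, H, W))

# Per-pixel loss, shape (N, H, W)
raw_loss = F.cross_entropy(logits, target, reduction="none")

# Check 1: uniform weights should reduce to the plain summed loss divided by the batch size
weights = torch.ones(N, H, W)
weighted = (weights * raw_loss).flatten(start_dim=1).sum(dim=1).mean()
reference = F.cross_entropy(logits, target, reduction="sum") / N
print(torch.allclose(weighted, reference))

# Check 2: a 0/1 weight map should behave like ignore_index (zero-weight pixels drop out)
mask = (torch.rand(N, H, W) > 0.5).float()
masked = (mask * raw_loss).flatten(start_dim=1).sum(dim=1).mean()
ignored_target = target.masked_fill(mask == 0, -100)  # -100 is the default ignore_index
reference_masked = F.cross_entropy(logits, ignored_target, reduction="sum") / N
print(torch.allclose(masked, reference_masked))
```

If either check fails for your real weights, the usual suspects are a shape mismatch that silently broadcasts, or summing over the wrong dimension.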
I want to make sure there are no mistakes here.
Thank you.