I want to weight each pixel when computing my loss function. First I compute the cross-entropy loss with no reduction (one value per pixel), then multiply it by my per-pixel weights, and then take the mean. If I set all the per-pixel weights to 1, I expect this to match the loss computed by the same CrossEntropyLoss (same class `weight`) with its default reduction. But it doesn't.
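import torch
# Unreduced loss: one value per pixel (`reduce=False` is deprecated; use reduction='none')
loss_function = torch.nn.CrossEntropyLoss(weight=weight, reduction='none')
# Same class weights, default reduction='mean'
loss_function_reduced = torch.nn.CrossEntropyLoss(weight=weight)
loss_tensor = loss_function(input, target)
weights = torch.ones_like(loss_tensor)  # all per-pixel weights set to 1
loss_tensor = (loss_tensor * weights).mean()
loss_tensor_reduced = loss_function_reduced(input, target)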
Why do `loss_tensor` and `loss_tensor_reduced` differ? Could someone help here?
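A minimal sketch of where the difference likely comes from, assuming random logits and hypothetical class weights (`class_weight`, the shapes, and the seed are illustrative, not from the post). Per the PyTorch `CrossEntropyLoss` documentation, each unreduced element already includes the factor `w[y_n]`, and with `reduction='mean'` the sum is divided by `sum_n w[y_n]`, not by the number of elements. `.mean()` divides by `N*H*W`, so the two only agree when the class weights of the targets happen to sum to the pixel count (e.g. when no class `weight` is passed).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical shapes: 2 images, 3 classes, 4x4 pixels
N, C, H, W = 2, 3, 4, 4
logits = torch.randn(N, C, H, W)
target = torch.randint(0, C, (N, H, W))
class_weight = torch.tensor([0.2, 1.0, 3.0])  # hypothetical class weights

# Per-pixel losses; each is already scaled by class_weight[target]
per_pixel = nn.CrossEntropyLoss(weight=class_weight, reduction='none')(logits, target)
# Default 'mean': sum of weighted losses / sum of class_weight[target]
reduced = nn.CrossEntropyLoss(weight=class_weight)(logits, target)

pixel_weights = torch.ones_like(per_pixel)  # all-ones per-pixel weights

# Plain mean divides by the element count N*H*W
naive = (per_pixel * pixel_weights).mean()

# Normalize by the summed effective weights instead (assumed normalization
# for combining class weights with per-pixel weights)
w_t = class_weight[target] * pixel_weights
matched = (per_pixel * pixel_weights).sum() / w_t.sum()

print(torch.allclose(matched, reduced))  # True
```

With all-ones `pixel_weights`, `matched` reproduces the built-in reduction, while `naive` differs from it by the factor `N*H*W / w_t.sum()`. For non-trivial per-pixel weights, dividing by `w_t.sum()` is one reasonable choice of normalization; dividing by the pixel count is another, and which one fits depends on how you want the loss scale to behave.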