torch.nn loss functions

Hi,
my inputs to the loss functions are shaped NCHW, (8, 2, 256, 256), for y_pred and NHW, (8, 256, 256), for y_true, where C is the number of classes. During training, nn.CrossEntropyLoss() works well, but other loss functions such as L1Loss raise errors because of the class dimension. Please advise!

Some loss functions (CrossEntropyLoss, NLLLoss) take class “indices” as targets.
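A minimal sketch of the index-target case, using the shapes from the question with random data standing in for real predictions and labels. `CrossEntropyLoss` accepts the NCHW logits and the NHW integer targets directly, and `NLLLoss` gives the same value once it is fed log-probabilities instead of raw logits.

```python
import torch
import torch.nn as nn

# Shapes from the question: N=8 images, C=2 classes, 256x256 pixels (random data).
logits = torch.randn(8, 2, 256, 256)          # y_pred: raw, unnormalized scores (NCHW)
target = torch.randint(0, 2, (8, 256, 256))   # y_true: one int64 class index per pixel (NHW)

# CrossEntropyLoss reads dim 1 of the input as the class dimension and
# picks out the score of the target class at every (n, h, w) position.
ce_loss = nn.CrossEntropyLoss()(logits, target)

# NLLLoss takes the same index target but expects log-probabilities as input.
nll_loss = nn.NLLLoss()(torch.log_softmax(logits, dim=1), target)

print(ce_loss.shape)                      # torch.Size([]): a scalar averaged over all pixels
print(torch.allclose(ce_loss, nll_loss))  # True
```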

Others (KLDivLoss, L1Loss, …) take probability distributions as targets, so you’d need to convert your target to “one-hot” encoding. In the olden days, we used scatter for this, but now there is a utility function: torch.nn.functional.one_hot (PyTorch 1.9.1 documentation).
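Here is a sketch of that conversion for the NHW target from the question, again with random placeholder data. `F.one_hot` appends the class dimension last, so a `permute` is needed to get NCHW; the `scatter_` call reproduces the older approach and builds the same tensor. Applying `softmax` to the logits before the L1 comparison is an assumption about how the network output should be turned into probabilities.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_classes = 2
target = torch.randint(0, num_classes, (8, 256, 256))  # NHW class indices (int64)

# F.one_hot appends the class dimension last: (N, H, W) -> (N, H, W, C).
# Permute to NCHW so it lines up with y_pred, and cast to float for L1Loss.
onehot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()

# Older approach: write 1.0 into channel `target` of a zero NCHW tensor.
onehot_scatter = torch.zeros(8, num_classes, 256, 256)
onehot_scatter.scatter_(1, target.unsqueeze(1), 1.0)
print(torch.equal(onehot, onehot_scatter))  # True

# Assumption: compare per-pixel softmax probabilities with the one-hot target.
logits = torch.randn(8, num_classes, 256, 256)
probs = torch.softmax(logits, dim=1)
loss = nn.L1Loss()(probs, onehot)
```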

Best regards

Thomas

Thanks, Thomas, for your reply.
I still cannot solve it. Suppose we have torch.randn(1, 2, 256, 256).to('cpu'), which has 2 channels of class probabilities. First off, how do we convert it to (1, 256, 256), and where should we apply the conversion: in the loss function or in the learner?

You don't convert the prediction down to (1, 256, 256); instead you want y_true to have shape (1, 2, 256, 256), matching y_pred, if you want to measure the L1 loss between the two distributions. The one-hot way is to set (_, CLASS, _, _) to 1 and (_, 1 - CLASS, _, _) to 0.
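For the two-class case in the follow-up, the one-hot target can also be built by hand, exactly as described: channel CLASS gets 1 and channel 1 - CLASS gets 0. One reasonable place for the conversion is inside the loss computation, so the network keeps producing NCHW output and the labels stay NHW. `OneHotL1Loss` below is a hypothetical helper, not a PyTorch class, and the softmax step is again an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class OneHotL1Loss(nn.Module):
    """Hypothetical helper (not part of torch.nn): L1 loss between softmax
    probabilities and a one-hot version of an (N, H, W) class-index target."""

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        num_classes = logits.shape[1]
        probs = torch.softmax(logits, dim=1)  # assumption: logits -> probabilities
        onehot = F.one_hot(target.long(), num_classes).permute(0, 3, 1, 2).to(probs.dtype)
        return F.l1_loss(probs, onehot)

# The example from the question: one image, two classes, 256x256 (random data).
y_pred = torch.randn(1, 2, 256, 256)
y_true = torch.randint(0, 2, (1, 256, 256))

# Two classes by hand: channel 0 is 1 where CLASS == 0, channel 1 is 1 where CLASS == 1.
manual = torch.stack([1 - y_true, y_true], dim=1).float()
auto = F.one_hot(y_true, 2).permute(0, 3, 1, 2).float()
print(torch.equal(manual, auto))  # True

loss = OneHotL1Loss()(y_pred, y_true)
```

Because the helper takes (logits, class indices) just like nn.CrossEntropyLoss(), it can be dropped in wherever the training code currently calls that criterion.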
