Problem with CrossEntropyLoss

Hi everyone!

I’m trying to use cross entropy as my loss function, but there seems to be something wrong with the dimensions of my output or target that I can’t spot.

The output of my net is a single-channel image with shape [batch_size, 1, 320, 180].
The target is a single-channel image with shape [batch_size, 1, 320, 180].
Each pixel has a value in the range [0, 1].

Now I want to pass them to CrossEntropyLoss, but I don’t know which shapes the function expects, and I’m getting this error:

RuntimeError: expected scalar type Long but found Float

I worked around it by casting the labels with label.long()
(I don’t know if this is correct)

With that, the loss function accepts my labels.

The problem is that the loss is now always 0 and the network does not train.

Any idea what the problem could be?

For a binary segmentation use case, you could simply use nn.BCEWithLogitsLoss, which would accept the shapes you are currently providing.
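
A minimal sketch of that suggestion, assuming random tensors as stand-ins for the real network output and target (the batch size of 4 and all variable names are made up for illustration). nn.BCEWithLogitsLoss applies the sigmoid internally, so the model output is passed as raw logits:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed sizes: batch_size is arbitrary, height/width come from the question.
batch_size, height, width = 4, 320, 180

# Stand-in for the network output: raw logits, no sigmoid applied.
logits = torch.randn(batch_size, 1, height, width, requires_grad=True)

# Stand-in for the target: float values in [0, 1], same shape as the logits.
target = torch.rand(batch_size, 1, height, width)

criterion = nn.BCEWithLogitsLoss()
bce_loss = criterion(logits, target)

# Gradients flow back to the logits, so the model can train.
bce_loss.backward()
print(bce_loss.item())
```

The target stays a float tensor here, so no .long() cast is needed.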

nn.CrossEntropyLoss is used for a multi-class segmentation use case. It can also be used for binary classification, but then the output needs the shape [batch_size, 2, height, width] and the target the shape [batch_size, height, width], containing the class indices (0 or 1) as a LongTensor.
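
As a rough sketch of those shapes, again with random stand-in tensors and made-up variable names. The second half illustrates a likely (not confirmed) explanation for the constant 0 loss in the question: with a single output channel the softmax over the channel dimension is always 1, and casting float values in [0, 1) with .long() truncates them to 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed sizes: batch_size is arbitrary, height/width come from the question.
batch_size, height, width = 4, 320, 180

criterion = nn.CrossEntropyLoss()

# Two output channels: one logit per class for every pixel.
two_class_logits = torch.randn(batch_size, 2, height, width, requires_grad=True)

# Class indices (0 or 1) as a LongTensor, without a channel dimension.
class_target = torch.randint(0, 2, (batch_size, height, width))

ce_loss = criterion(two_class_logits, class_target)
ce_loss.backward()
print(ce_loss.item())

# Degenerate case: one channel means one class, so log-softmax is 0 everywhere.
one_class_logits = torch.randn(batch_size, 1, height, width)
zero_target = torch.zeros(batch_size, height, width, dtype=torch.long)
degenerate_loss = criterion(one_class_logits, zero_target)
print(degenerate_loss.item())

# Casting floats in [0, 1) to long truncates every value to class index 0.
truncated = torch.rand(2, 3).long()
print(truncated)
```

If the real target has a channel dimension of size 1, something like target.squeeze(1).long() would give the expected [batch_size, height, width] shape, provided the values are already exactly 0 or 1.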

Thanks for your answer, it helped me a lot.

It worked for me with nn.BCEWithLogitsLoss, but I would like to make it work with nn.CrossEntropyLoss also.

In my problem, the goal is to make the network output look as much as the target as possible. That makes me think that my problem is not a classification problem but a regression problem.
So when I try to give the network output the shape [batch_size, C, height, width] (C being the number of classes), I cannot choose C because I don’t have any classes.

How should I use nn.CrossEntropyLoss for a regression problem?

That won’t be possible, since you would need to provide class indices to calculate the cross entropy.
Based on “the goal is to make the network output look as much as the target as possible”, it seems indeed that you might be working on a regression problem, so you could use e.g. nn.MSELoss instead.
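
A minimal sketch of the regression setup, assuming the same single-channel shapes as in the question and random stand-in tensors (variable names are made up for illustration). nn.MSELoss compares the output and the float target element-wise, so no class dimension and no .long() cast are involved:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed sizes: batch_size is arbitrary, height/width come from the question.
batch_size, height, width = 4, 320, 180

# Stand-in for the network output and the continuous-valued target image.
prediction = torch.randn(batch_size, 1, height, width, requires_grad=True)
regression_target = torch.rand(batch_size, 1, height, width)

criterion = nn.MSELoss()
mse_loss = criterion(prediction, regression_target)

# The default reduction averages the squared error over all elements.
manual_mse = ((prediction - regression_target) ** 2).mean()

mse_loss.backward()
print(mse_loss.item())
```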
