[QuickQuestion] Is this the correct way to use the crossEntropy function?

Say I have 3 classes, and my output has shape (Batch x Channel x H x W x D) 32 x 3 x 25 x 25 x 25.
Since the cross entropy function requires 2D input (N, C), I reshaped my output with .view() to:
(500000, 3)

where 500000 = 32 x 25 x 25 x 25.

I also reshaped the label to (500000).

Then I passed both to the loss function: loss = criterion(output, label).

Is it correct to do so? Thanks in advance!!

Hi
The cross entropy loss takes two arguments: the scores your model predicted for each class in the forward pass, and the ground-truth labels to compare them against. This holds per example in a minibatch, so the input to the cross entropy loss should have shape (N, C) for the predicted scores and (N) for the true labels. Here N is the size of your minibatch, so if your minibatch has 32 elements, N = 32 for both inputs and targets. Similarly, C is the number of classes, so with 3 classes C = 3, and each entry along this dimension is the score your model assigned to that class (a raw, unnormalized score: the loss applies log-softmax internally, so you don't apply softmax yourself). Thus an input of shape (N, C) holds C scores for each of the N examples in your minibatch. For the values you provided, the inputs to the loss function would be (32, 3) and (32).
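A minimal sketch of the (N, C) / (N) convention, assuming PyTorch's nn.CrossEntropyLoss and random stand-in tensors (the data here is illustrative, not from the original post). It also recomputes the loss by hand to show what the function averages:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C = 32, 3  # minibatch size and number of classes from the thread

scores = torch.randn(N, C)            # (N, C): C raw scores (logits) per example
targets = torch.randint(0, C, (N,))   # (N,): one class index in {0, 1, 2} per example

criterion = nn.CrossEntropyLoss()     # averages the per-example losses by default
loss = criterion(scores, targets)

# Same value by hand: negative log-softmax of each example's true-class score, averaged.
manual = -torch.log_softmax(scores, dim=1)[torch.arange(N), targets].mean()
print(loss.item(), manual.item())
```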
However, the dimensions of your output correspond to volumetric data, which is probably not what you want to feed to cross entropy. Cross entropy computes the loss of a predicted class among a fixed number of classes. Typically this is done after one or more fully connected layers that take as input a flattened (done with view()) version of the output of the preceding (typically convolutional) layers. I don't know what you're trying to do, but you likely want to flatten first and then compute the scores for each class.
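For the whole-volume case, a sketch of the flatten-then-fully-connected pattern might look like the following. The TinyClassifier module, the single input channel, its layer sizes, and the use of nn.AdaptiveAvgPool3d are hypothetical choices for illustration only; torch.flatten(x, 1) does the same job as x.view(x.size(0), -1) here.

```python
import torch
import torch.nn as nn

class TinyClassifier(nn.Module):
    """Hypothetical whole-volume classifier: conv features -> flatten -> fully connected."""

    def __init__(self, in_channels: int = 1, num_classes: int = 3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv3d(in_channels, 8, kernel_size=3, padding=1),  # keeps 25x25x25
            nn.ReLU(),
            nn.AdaptiveAvgPool3d(4),  # shrinks every volume to 8 x 4 x 4 x 4
        )
        self.fc = nn.Linear(8 * 4 * 4 * 4, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)      # (N, 8, 4, 4, 4)
        x = torch.flatten(x, 1)   # (N, 512)
        return self.fc(x)         # (N, num_classes): one score per class

torch.manual_seed(0)
model = TinyClassifier()
volumes = torch.randn(32, 1, 25, 25, 25)  # 32 single-channel volumes
targets = torch.randint(0, 3, (32,))      # one label per whole volume
scores = model(volumes)                   # (32, 3)
loss = nn.CrossEntropyLoss()(scores, targets)
print(scores.shape, loss.item())
```

Here the loss sees exactly the (N, C) and (N) shapes described above: one prediction per volume, not per voxel.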
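If, however, you are trying to classify each data point (each pixel) within a single example, and each entry of the channel dimension is the score for one class, then yes, what you're doing is right. Just be careful to align each point properly with its matching label: because the class scores sit in dimension 1, move that dimension to the end (e.g. with .permute(0, 2, 3, 4, 1)) before flattening, otherwise a plain .view(-1, 3) mixes scores from different pixels into the same row.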
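For the per-pixel (per-voxel) case, here is a sketch of the alignment issue, assuming the 32 x 3 x 25 x 25 x 25 shapes from the question. The scores are built from one-hot labels (a hypothetical stand-in for a well-trained network) so that the correctly aligned loss is near zero, which makes the misaligned .view(-1, 3) easy to spot. It also shows that nn.CrossEntropyLoss in recent PyTorch versions accepts (N, C, d1, d2, d3) input with (N, d1, d2, d3) targets directly, with no reshaping.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W, D = 32, 3, 25, 25, 25  # shapes from the question

label = torch.randint(0, C, (B, H, W, D))  # one class index per voxel
# Hypothetical "perfect" scores: 5.0 for the true class, 0.0 elsewhere.
# one_hot gives (B, H, W, D, C); permute moves classes to dim 1 like a real output.
onehot = F.one_hot(label, C).permute(0, 4, 1, 2, 3).contiguous()  # (B, C, H, W, D)
output = 5.0 * onehot.float()

criterion = nn.CrossEntropyLoss()

# Correct: move the class dimension last so each row holds ONE voxel's C scores.
flat_output = output.permute(0, 2, 3, 4, 1).reshape(-1, C)  # (500000, 3)
flat_label = label.reshape(-1)                              # (500000,)
loss_flat = criterion(flat_output, flat_label)

# Equivalent: pass the 5D output and 4D labels directly.
loss_direct = criterion(output, label)

# Wrong: same shape, but each row now groups one class's scores from several voxels.
loss_wrong = criterion(output.view(-1, C), flat_label)

print(loss_flat.item(), loss_direct.item(), loss_wrong.item())
```

The first two losses agree and are small; the plain .view() version is much larger even though its shapes look right, which is why shape checks alone won't catch the misalignment.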

One final note: in either case, remember that with 3 classes your label values must be the integer class indices {0, 1, 2}.
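A small sanity check along these lines can catch label problems before they reach the loss. check_targets is a hypothetical helper, and the 1-based raw labels are an assumed example of data that needs shifting to {0, 1, 2}:

```python
import torch

def check_targets(targets: torch.Tensor, num_classes: int) -> None:
    """Hypothetical check for class-index targets passed to nn.CrossEntropyLoss."""
    if targets.dtype != torch.long:
        raise TypeError(f"expected torch.long class indices, got {targets.dtype}")
    lo, hi = int(targets.min()), int(targets.max())
    if lo < 0 or hi >= num_classes:
        raise ValueError(f"labels must lie in [0, {num_classes - 1}], found [{lo}, {hi}]")

raw = torch.tensor([1, 2, 3, 3, 1])  # assumed 1-based labels
labels = raw - 1                     # shift to {0, 1, 2}
check_targets(labels, num_classes=3) # passes silently
```

Calling check_targets(raw, num_classes=3) instead would raise a ValueError, because the label 3 is out of range for 3 classes.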

Hi @RicCu, thank you so much for the detailed reply!!
Yes, my network is an encoder-decoder architecture and it outputs pixel-wise predictions. Sorry, I should have mentioned that in the first place.

Thanks again! :slight_smile: