CrossEntropyLoss for multiple output classification

Given an input, I would like to do multiple classification tasks.

Each input needs to be classified into one of 5 classes, and there are 6 such classification tasks. I could build six separate Linear(some_number, 5) layers and return the results as a tuple from forward(), then call the loss function 6 times and sum the losses to produce the overall loss.

I am wondering whether I can do better than this. For example, can I have a single Linear(some_number, 5*6) as the output layer, then reshape the logits to (6,5) and use them? The documentation for CrossEntropyLoss mentions a “K-dimensional loss”. Is it meant for use cases like this one? If so, can anyone share some sample code? (When I tried to use it, I got confused about how to reshape the output.)

Hi Suresh!

Building six separate Linear layers and summing the six losses is a
perfectly reasonable approach and will work fine. It may well introduce
de minimis inefficiency, but any such cost will be negligible compared
with the cost of backpropagation.
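
For concreteness, here is a minimal sketch of that first scheme. The sizes (batch of 4, 16 input features), the class name MultiHead, and the random data are assumptions for illustration only, not anything from your model.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumed): 4 samples, 16 features, 5 classes, 6 tasks
n_batch, n_features, n_class, n_task = 4, 16, 5, 6


class MultiHead(nn.Module):
    """Hypothetical model with one independent Linear head per task."""

    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList(
            [nn.Linear(n_features, n_class) for _ in range(n_task)]
        )

    def forward(self, x):
        # Tuple of n_task tensors, each of shape [n_batch, n_class]
        return tuple(head(x) for head in self.heads)


torch.manual_seed(0)
model = MultiHead()
x = torch.randn(n_batch, n_features)
# Integer class labels in [0, n_class - 1], one column per task
target = torch.randint(0, n_class, (n_batch, n_task))

loss_fn = nn.CrossEntropyLoss()
logits = model(x)
# Call the loss once per task and sum the six scalar losses
total_loss = sum(loss_fn(logits[t], target[:, t]) for t in range(n_task))
total_loss.backward()
```

A single backward() on the summed loss sends gradients to all six heads at once.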

Using a single Linear(some_number, 5*6) output is also fine. I would
probably use this second approach, not for efficiency reasons, but
because it seems a little more organized to me.

Note, for the “K-dimensional loss” case, you will want to reshape your
logits to have shape [nBatch, 5, 6] (not [nBatch, 6, 5]).

Let me use the following terms: nBatch (batch size), nClass = 5,
and nTask = 6.

The output of your model should have shape
[nBatch, nClass * nTask]. You will reshape it, e.g., with
output.reshape(-1, nClass, nTask), to have shape
[nBatch, nClass, nTask], and then pass it as the input (prediction)
into CrossEntropyLoss. When you do this, the target you pass into
CrossEntropyLoss must have shape [nBatch, nTask] (no class
dimension), and consist of integer class labels that run from 0 to
nClass - 1.

(This second scheme only works, of course, if your multiple classification
tasks all have the same number of classes. If not, you would use your
first scheme.)
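
Here is a minimal sketch of the second scheme with the shapes described above. The input size of 16 and the random data are assumptions for illustration. The per-task loop at the end is only a sanity check, not something you would need in training.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes (assumed): 4 samples, 16 features, 5 classes, 6 tasks
n_batch, n_features, n_class, n_task = 4, 16, 5, 6

torch.manual_seed(0)
layer = nn.Linear(n_features, n_class * n_task)
x = torch.randn(n_batch, n_features)
# Target has no class dimension: integer labels in [0, n_class - 1]
target = torch.randint(0, n_class, (n_batch, n_task))

output = layer(x)                             # [n_batch, n_class * n_task]
logits = output.reshape(-1, n_class, n_task)  # [n_batch, n_class, n_task]

# Class dimension must be dim 1; the loss is averaged over n_batch * n_task
loss = nn.CrossEntropyLoss()(logits, target)

# Sanity check: the same value as averaging the per-task losses
per_task = [
    F.cross_entropy(logits[:, :, t], target[:, t]) for t in range(n_task)
]
mean_of_tasks = torch.stack(per_task).mean()
```

Note that with the default reduction = 'mean', this loss is the average over tasks, so it equals the summed loss of the first scheme divided by nTask. Multiply by nTask if you want the two to match exactly.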


K. Frank

Thanks for the detailed reply.

I somehow find this order less intuitive, with nClass appearing before nTask.

Outside the loss function, I will apply torch.nn.functional.softmax to the logits, using the dim parameter to normalize along the class dimension.
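
Roughly what I have in mind is sketched below, with random logits standing in for the reshaped model output (the sizes are just for illustration). With the [nBatch, nClass, nTask] layout the class dimension is dim = 1, and a transpose gives the task-first view I find easier to read.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes (assumed): 4 samples, 5 classes, 6 tasks
n_batch, n_class, n_task = 4, 5, 6

torch.manual_seed(0)
# Stand-in for the reshaped model output, [n_batch, n_class, n_task]
logits = torch.randn(n_batch, n_class, n_task)

# Normalize over the class dimension (dim=1), separately for each task
probs = F.softmax(logits, dim=1)   # [n_batch, n_class, n_task]
pred = probs.argmax(dim=1)         # [n_batch, n_task], predicted class per task

# Task-first view for inspection: [n_batch, n_task, n_class]
probs_task_first = probs.transpose(1, 2)
```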

Thanks again.