# CrossEntropyLoss for multiple output classification

Given an input, I would like to do multiple classification tasks.

Each input needs to be classified into one of 5 classes. There are 6 such classification tasks to be done. I could build six separate Linear(some_number, 5) layers and return the result as tuple in the forward() function. Then call the loss function 6 times and sum the losses to produce the overall loss.
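A minimal sketch of this first scheme (the feature size `16` and batch size are hypothetical placeholders for illustration):

```python
import torch
import torch.nn as nn

class MultiHeadClassifier(nn.Module):
    def __init__(self, in_features=16, n_class=5, n_task=6):
        super().__init__()
        # one Linear(some_number, 5) head per classification task
        self.heads = nn.ModuleList(
            nn.Linear(in_features, n_class) for _ in range(n_task)
        )

    def forward(self, x):
        # return the per-task logits as a tuple
        return tuple(head(x) for head in self.heads)

model = MultiHeadClassifier()
criterion = nn.CrossEntropyLoss()

x = torch.randn(8, 16)                  # batch of 8 inputs
targets = torch.randint(5, (8, 6))      # one integer label per task

logits = model(x)
# call the loss once per task and sum to get the overall loss
loss = sum(criterion(lg, targets[:, i]) for i, lg in enumerate(logits))
loss.backward()
```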

I am wondering if I could do better than this. For example, could I have a single Linear(some_number, 5*6) as the output, then reshape the logits to (6, 5) and use that? The documentation for CrossEntropyLoss mentions a "K-dimensional loss." Is it meant for use cases like this? If so, can anyone share sample code? (When I tried to use it, I got confused about how to reshape the output.)

https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html

Hi Suresh!

Your first approach (six separate `Linear` layers with six loss calls) is
perfectly reasonable and will work fine. It may well introduce de minimis
inefficiency, but any such cost will be negligible compared with the cost
of backpropagation.

Your second approach (a single `Linear (some_number, 5 * 6)`) is also
fine. I would probably use this second approach -- not for efficiency
reasons, but because it seems a little more organized to me.

Note, for the "K-dimensional loss" case, you will want to reshape your
logits to have shape `[nBatch, 5, 6]` (not `[nBatch, 6, 5]`).

Let me use the following terms: `nBatch` (batch size), `nClass = 5`,
and `nTask = 6`.

The output of your model should have shape
`[nBatch, nClass * nTask]`, you will reshape it, e.g.,
`output.reshape(-1, nClass, nTask)`, to have shape
`[nBatch, nClass, nTask]`, and then pass it as the `input` (prediction)
into `CrossEntropyLoss`. When you do this, the `target` you pass into
`CrossEntropyLoss` must have shape `[nBatch, nTask]` (no class
dimension), and consist of integer class labels that run from `0` to
`nClass - 1`.
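Putting the above together, a minimal sketch of the "K-dimensional loss" usage (the model output is faked with random logits here):

```python
import torch

nBatch, nClass, nTask = 8, 5, 6

# hypothetical model output: a single Linear(some_number, nClass * nTask)
output = torch.randn(nBatch, nClass * nTask)

# reshape so that the class dimension comes second: [nBatch, nClass, nTask]
input = output.reshape(-1, nClass, nTask)

# target has shape [nBatch, nTask] -- no class dimension -- and holds
# integer class labels in [0, nClass - 1]
target = torch.randint(nClass, (nBatch, nTask))

loss = torch.nn.CrossEntropyLoss()(input, target)
print(loss.shape)   # torch.Size([]) -- a scalar (default "mean" reduction)
```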

(This second scheme only works, of course, if your multiple classification
tasks all have the same number of classes. If not, you would use your
first scheme.)

Best.

K. Frank


Frank,
I somehow find this order, with `nClass` appearing before `nTask`, less intuitive.
I will be using `torch.nn.functional.softmax` to convert the logits to probabilities outside the loss function, and will use the `dim` parameter to apply it along the correct dimension.
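For what it's worth, in the `[nBatch, nClass, nTask]` layout the class dimension is `dim=1`, so a sketch of that softmax call would be:

```python
import torch
import torch.nn.functional as F

nBatch, nClass, nTask = 8, 5, 6
logits = torch.randn(nBatch, nClass, nTask)   # stand-in for model output

# the class dimension is dim=1 in the [nBatch, nClass, nTask] layout
probs = F.softmax(logits, dim=1)

# for each sample and each task, probabilities sum to one across classes
print(probs.sum(dim=1))   # tensor of ones, shape [nBatch, nTask]
```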