I have a classification problem where each sample's output consists of four values, each of which is one of 6 classes, i.e., y has shape (batchsize, 4). So far, I have created a toy model that outputs a logits tensor out of shape (batchsize, 6 (= num_classes), 4). I am calculating the loss for each of the four target values and summing them up, followed by backprop.

If I understand what you are trying to do here (and I am not at all
sure that I do), then torch.nn.CrossEntropyLoss should do what
you want without further manipulation. (Look at the parts where
the documentation talks about the “K-dimensional case.”)

Let the output of your model be a tensor of shape (batchsize, 6, 4).
I will call this tensor prediction. I will call the last dimension (the
“4” dimension) the “channel.” (So in your use case you have four
“channels.”) For a given sample in the batch and a given channel,
your prediction consists of six logits, one for each of your six classes.

You use your loss function to compare your prediction with your target, that is, with the known class labels that you use for training. target will be a tensor of shape (batchsize, 4), and for a given
sample within the batch and for a given channel (the “4” dimension)
you will have a single value that is an integer class label ranging
from 0 to 5.

(Note that prediction and target are not of the same shape: prediction carries all of the batchsize, class, and channel dimensions,
while target, because it uses a single integer class label per item,
carries only the batchsize and channel dimensions.)
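
To make the shapes concrete, here is a minimal sketch. The batch size of 8 and the random tensors are only stand-ins I made up for your model output and labels; the one real requirement is that target holds integer class labels (dtype torch.int64).

```python
import torch

torch.manual_seed(0)
batchsize, num_classes, num_channels = 8, 6, 4  # batchsize = 8 is an arbitrary choice

# Stand-in for the model output: logits of shape (batchsize, 6, 4).
prediction = torch.randn(batchsize, num_classes, num_channels, requires_grad=True)

# Integer class labels in 0..5, shape (batchsize, 4); randint returns int64 by default.
target = torch.randint(0, num_classes, (batchsize, num_channels))

criterion = torch.nn.CrossEntropyLoss()
loss = criterion(prediction, target)  # a single scalar
loss.backward()                       # gradients flow back to every logit

print(prediction.shape, target.shape, loss.shape)
```

Note that the class dimension has to sit at position 1, right after the batch dimension, which is already the layout you described.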

By default, CrossEntropyLoss will average your loss over both the
batchsize and channel dimensions. If you prefer the sum, then pass
in the optional reduction = 'sum' argument when you construct
your CrossEntropyLoss loss-function object.
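
As a quick check of the two reductions, the sketch below compares the single K-dimensional call against an explicit loop over the four channels. The loop is only my guess at what a per-channel version might look like, not your actual code, and the tensors are again random stand-ins.

```python
import torch

torch.manual_seed(0)
batchsize, num_classes, num_channels = 8, 6, 4  # arbitrary illustrative sizes

prediction = torch.randn(batchsize, num_classes, num_channels)
target = torch.randint(0, num_classes, (batchsize, num_channels))

sum_criterion = torch.nn.CrossEntropyLoss(reduction='sum')
mean_criterion = torch.nn.CrossEntropyLoss()  # default reduction='mean'

# One call over all channels at once (the K-dimensional case).
loss_sum = sum_criterion(prediction, target)

# Hypothetical per-channel loop: slice out channel c and add up the losses.
loop_sum = sum(
    sum_criterion(prediction[:, :, c], target[:, c])
    for c in range(num_channels)
)

# The default mean divides the summed loss by batchsize * num_channels.
loss_mean = mean_criterion(prediction, target)

print(torch.allclose(loss_sum, loop_sum))
print(torch.allclose(loss_mean, loss_sum / (batchsize * num_channels)))
```

Both comparisons should agree up to floating-point round-off, so the only difference between the two reductions is an overall scale factor on the loss (and therefore on the gradients).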

@KFrank I had a look over the “K-dimensional case” in torch.nn.CrossEntropyLoss and it really helped solve my problem. What I was essentially doing can be done with criterion = torch.nn.CrossEntropyLoss(reduction='sum').