Cross entropy for multivariate classification

I have a classification problem where each sample's output consists of four values, each of which is one of six classes, i.e. `y` has shape `(batchsize, 4)`. So far, I have created a toy model that outputs a logits tensor `out` of shape `(batchsize, 6(=num_classes), 4)`. I am calculating the loss for each of the four target values and summing them up, followed by backprop:

```python
out = model(x)  # note: `in` is a reserved word in Python, so call the input `x`
out.shape  # (batchsize, 6, 4)

# slice out the six logits for each of the four output positions
out1 = out[:, :, 0]
out1.shape  # (batchsize, 6)
out2 = out[:, :, 1]
out3 = out[:, :, 2]
out4 = out[:, :, 3]

# the corresponding integer class labels
y1 = y[:, 0]
y1.shape  # (batchsize,)
y2 = y[:, 1]
y3 = y[:, 2]
y4 = y[:, 3]

total_loss = criterion(out1, y1) + criterion(out2, y2) + criterion(out3, y3) + criterion(out4, y4)

optimizer.zero_grad()  # clear stale gradients before backprop
total_loss.backward()
optimizer.step()
```

I want to know if this is the correct way to handle multivariate classification.

Hi Hashir!

If I understand what you are trying to do here (and I am not at all
sure that I do), then torch.nn.CrossEntropyLoss should do what
you want without further manipulation. (Look at the parts where
the documentation talks about the "K-dimensional case.")

Let the output of your model be a tensor of shape `(batchsize, 6, 4)`.
I will call this tensor `prediction`. I will call the last dimension (the
"4" dimension) the "channel." (So in your use case you have four
"channels.") For a given sample in the batch and a given channel,
your `prediction` consists of six logits, one for each of your six classes.

You use your loss function to compare your `prediction` with your
`target`. These are the known class labels that you use for training.
`target` will be a tensor of shape `(batchsize, 4)`, and for a given
sample within the batch and for a given channel (the "4" dimension)
you will have a single value that is an integer class label ranging
from 0 to 5.

(Note that `prediction` and `target` are not of the same shape:
`prediction` carries all of the batchsize, class, and channel dimensions,
while `target`, because it uses a single integer class label per item,
carries only the batchsize and channel dimensions.)
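
To make these shapes concrete, here is a minimal sketch (the sizes and
the `prediction` / `target` names are illustrative, standing in for your
model output and labels):

```python
import torch

batchsize, num_classes, num_channels = 8, 6, 4  # illustrative sizes

prediction = torch.randn(batchsize, num_classes, num_channels)  # logits
target = torch.randint(num_classes, (batchsize, num_channels))  # labels in 0..5

criterion = torch.nn.CrossEntropyLoss()
loss = criterion(prediction, target)  # a single scalar; no per-channel slicing needed
```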

By default, `CrossEntropyLoss` will average your loss over both the
batchsize and channel dimensions. If you prefer the sum, then pass
in the optional `reduction = 'sum'` argument when you construct
your `CrossEntropyLoss` loss-function object.
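
If you want to check that this reproduces the per-channel sum from your
original code, a quick sanity check (continuing the sketch above) would be:

```python
criterion_sum = torch.nn.CrossEntropyLoss(reduction = 'sum')
loss_sum = criterion_sum(prediction, target)

# sum the per-channel losses by hand, as in the original post
manual = sum(
    torch.nn.functional.cross_entropy(
        prediction[:, :, k], target[:, k], reduction = 'sum'
    )
    for k in range(num_channels)
)
assert torch.allclose(loss_sum, manual)
```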

Best.

K. Frank


@KFrank I had a look over the "K-dimensional case" in torch.nn.CrossEntropyLoss and it really helped solve my problem. What I was essentially doing can be done with `criterion = torch.nn.CrossEntropyLoss(reduction='sum')`.
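
For anyone who finds this later, the training step then collapses to something like this (reusing `model`, `x`, `y`, and `optimizer` from my original snippet):

```python
criterion = torch.nn.CrossEntropyLoss(reduction='sum')

out = model(x)            # (batchsize, 6, 4) logits
loss = criterion(out, y)  # y has shape (batchsize, 4)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```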

Thanks for the help!