CrossEntropyLoss - Why is it possible to have more columns in input than there are classes in the target?

[Edited: I changed the contents because the question wasn't articulated well and I hadn't provided a minimal reproducible example.]

As per the documentation of cross_entropy (CrossEntropyLoss, PyTorch 2.1 documentation):

The input is expected to contain raw, unnormalized scores for each class. input has to be a Tensor of size (minibatch, C) for the batched case. The target that this criterion expects should contain class indices in the range [0, C) where C is the number of classes.
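The range requirement only goes one way: each target index must be smaller than C, but nothing requires every one of the C columns to appear as a target. A minimal sketch with random logits and arbitrary target values, assuming CPU tensors (the exact exception type for an out-of-range index may differ by PyTorch version or device, so both IndexError and RuntimeError are caught):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

N, C = 3, 8                      # batch size and number of logit columns
logits = torch.randn(N, C)       # shape [N, C]: one raw score per class

# Valid: every target index lies in [0, C); columns 1, 2, 3, 5, 6 are simply never used
target_ok = torch.tensor([0, 4, 7])
loss_ok = F.cross_entropy(logits, target_ok)
print(loss_ok)                   # scalar tensor (default reduction="mean")

# Invalid: index 8 is outside [0, C), so PyTorch raises an error
target_bad = torch.tensor([0, 4, 8])
try:
    F.cross_entropy(logits, target_bad)
    raised = False
except (IndexError, RuntimeError) as err:
    raised = True
    print(type(err).__name__, err)
```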

Why does the following code work?

import torch
import torch.nn.functional as F
input = torch.cat([torch.randn(3, 5), torch.randn(3, 3)], dim=1)  # shape [3, 8]
target = torch.empty(3, dtype=torch.long).random_(5)              # values in [0, 5)
F.cross_entropy(input, target)

It's correct as far as syntax is concerned. The logits are expected to have shape [N, C], where N is the batch size and C is the number of classes. The target tensor is expected to have shape [N], where each value is the index of the expected/target class, in the range [0, C).

In this case, it seems there are 100 classes.

There are only 50 classes. logits00 and logits01 are concatenated along dim=1, which makes C=100 for the input argument. Note that 50 is just an illustrative number here. The main idea is that logits00 and logits01 are square tensors of shape, say, CxC. After concatenation, the shape becomes Cx2C. However, the target tensor still contains only C classes. I hope this makes sense now.
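To see what the loss does with the extra columns, here is a minimal sketch of the CxC + CxC case, assuming C = 50 and using random tensors as hypothetical stand-ins for logits00 and logits01. cross_entropy simply treats the result as a 2C-class problem in which classes C to 2C-1 never occur as labels; their logits still enter the softmax denominator, so they receive a positive gradient that pushes them down:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

C = 50                                                  # illustrative, as in the thread
logits00 = torch.randn(C, C, requires_grad=True)        # hypothetical stand-ins for the two
logits01 = torch.randn(C, C, requires_grad=True)        # square logit tensors (batch = C rows)

inp = torch.cat([logits00, logits01], dim=1)            # shape [C, 2C]: 2C "classes"
target = torch.randint(0, C, (C,))                      # labels only ever in [0, C)

loss = F.cross_entropy(inp, target)                     # valid: every label < 2C
loss.backward()

# The softmax runs over all 2C columns, so the extra columns still get gradient.
# For a column j that is never the target: d(loss)/d(logit_ij) = softmax_ij / N.
probs = F.softmax(inp.detach(), dim=1)
expected_grad01 = probs[:, C:] / C                      # N = C rows, mean reduction
print(inp.shape)                                        # torch.Size([50, 100])
print(torch.allclose(logits01.grad, expected_grad01, atol=1e-6))
```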

Could you please print target.max() for a couple of epochs and let me know the values? Also, which dataset are you using?

torch.max(target) is 49. I am using the Stanford Online Products dataset.
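A quick way to run the check suggested above on every batch is a small assertion helper. check_targets below is a hypothetical name, not a PyTorch API; it assumes plain class-index targets and does not account for ignore_index (default -100), which would trip the non-negativity check:

```python
import torch

def check_targets(logits: torch.Tensor, target: torch.Tensor) -> None:
    """Hypothetical helper: verify the shape/range contract of F.cross_entropy
    for class-index targets before computing the loss."""
    n, c = logits.shape
    assert target.dtype == torch.long, "class-index targets must be int64"
    assert target.shape == (n,), f"expected target shape ({n},), got {tuple(target.shape)}"
    assert target.min().item() >= 0, "negative class index"
    assert target.max().item() < c, f"target.max()={target.max().item()} must be < C={c}"

# Numbers from the thread: 50 real classes, logits with 2 * 50 = 100 columns
logits = torch.randn(50, 100)
target = torch.randint(0, 50, (50,))   # largest possible value is 49
check_targets(logits, target)          # passes, since 49 < 100
```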

I assume the discussion above targets the original post before the edit, since the current code snippet uses 8 valid classes. Referring to the current snippet:

This code snippet is working, as it meets the requirements.
The logits have a shape of [batch_size=3, nb_classes=8] while the targets have a shape of [batch_size=3] and contain values in [0, nb_classes-1]. In this case you are limiting the target values to [0, 4] which is inside the valid range of [0, 7].
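To confirm that the 3 unused columns really participate as classes, the loss can be recomputed by hand as -logit[target] + logsumexp(logits) over all 8 columns, averaged over the batch (the default reduction='mean'). A minimal sketch reproducing the snippet from the question with a fixed seed:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# The snippet from the question: 5 + 3 = 8 logit columns, targets drawn from [0, 5)
inp = torch.cat([torch.randn(3, 5), torch.randn(3, 3)], dim=1)    # shape [3, 8]
target = torch.empty(3, dtype=torch.long).random_(5)              # values in {0, ..., 4}

loss = F.cross_entropy(inp, target)

# Same value by hand: -logit[target] + logsumexp over ALL 8 columns, averaged over the batch
rows = torch.arange(inp.size(0))
manual = (-inp[rows, target] + torch.logsumexp(inp, dim=1)).mean()

print(inp.shape)                                  # torch.Size([3, 8])
print(torch.allclose(loss, manual, atol=1e-6))    # True
```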
