Dtype=int for classification labels

I would like to know whether, for a multi-class classification task, the labels must be passed with dtype ‘int’. If so, what does that signify?

Hi LB!

For a multi-class classification task, one would typically use
CrossEntropyLoss as the loss criterion. In recent versions of PyTorch
(1.10 and later), CrossEntropyLoss supports two types of labels.

Let’s say that your prediction (the input passed to CrossEntropyLoss)
has shape [nBatch, nClass]. Then the labels (the target) can either be
integer categorical class labels of shape [nBatch], with values in the
range [0, nClass - 1], and of type torch.int64 (which is long), but not
of type torch.int32 (which is int).
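Here is a minimal sketch of the integer-label case, assuming PyTorch is installed. The sizes and label values are arbitrary, chosen only for illustration. It also checks the result against the explicit formula (mean negative log-softmax score of the correct class), and shows converting int32 labels with .long():

```python
import torch
import torch.nn as nn

nBatch, nClass = 4, 3
criterion = nn.CrossEntropyLoss()

# Raw, unnormalized scores (logits) of shape [nBatch, nClass].
logits = torch.randn(nBatch, nClass)

# Integer class labels: shape [nBatch], dtype torch.int64 (long),
# each value in the range [0, nClass - 1].
target = torch.tensor([0, 2, 1, 2])
print(target.dtype)  # torch.int64 is the default for Python ints

loss = criterion(logits, target)

# Same quantity computed by hand: pick the log-probability of the
# correct class for each sample, negate, and average over the batch.
manual = -torch.log_softmax(logits, dim=1)[torch.arange(nBatch), target].mean()

# Labels that arrive as torch.int32 can be converted with .long().
target_int32 = target.to(torch.int32)
loss_converted = criterion(logits, target_int32.long())
```
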

Or they can be “soft” probabilistic labels of shape [nBatch, nClass]
and of the floating-point type that matches the type of input (typically
float32 or float64).
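A minimal sketch of the soft-label case follows, assuming PyTorch 1.10 or later and the default settings (no class weights, no label smoothing, reduction='mean'). The random soft targets are made up for illustration. It also shows that one-hot probabilities (cast to the input's float type) give the same loss as the equivalent integer labels:

```python
import torch
import torch.nn as nn

nBatch, nClass = 4, 3
criterion = nn.CrossEntropyLoss()

logits = torch.randn(nBatch, nClass)  # float32 by default

# Soft labels: shape [nBatch, nClass], same floating-point dtype as logits;
# each row is a probability distribution over the classes (sums to 1).
soft_target = torch.softmax(torch.randn(nBatch, nClass), dim=1)
soft_loss = criterion(logits, soft_target)

# Same quantity by hand: -sum_c p_c * log_softmax(logits)_c, averaged over the batch.
soft_manual = -(soft_target * torch.log_softmax(logits, dim=1)).sum(dim=1).mean()

# Integer labels are the special case of one-hot probabilities.
hard_target = torch.tensor([0, 2, 1, 2])
one_hot = nn.functional.one_hot(hard_target, num_classes=nClass).to(logits.dtype)
one_hot_loss = criterion(logits, one_hot)
hard_loss = criterion(logits, hard_target)
```
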


K. Frank