Hi,
I would like to know, for a multi-class classification task, whether it is required to pass the labels’ dtype as ‘int’. If yes, what does it signify?
Hi LB!
For a multi-class classification task, one would typically use CrossEntropyLoss as the loss criterion. In recent versions of pytorch, CrossEntropyLoss supports two types of labels.
Let’s say that your prediction (the input passed to CrossEntropyLoss) has shape [nBatch, nClass]. Then the labels (the target) can be either integer categorical class labels of shape [nBatch] and type torch.int64 (which is long) (but not type torch.int32, which is int).
Or they can be “soft” probabilistic labels of shape [nBatch, nClass] and of the floating-point type that matches the type of input (typically float32 or float64).
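Here is a minimal sketch of the two label types, in case it helps. The batch size, class count, and variable names are made up for illustration, and the exact error raised for int32 labels may differ between pytorch versions, so the script only records whether the call succeeded.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_batch, n_class = 4, 3  # hypothetical sizes, chosen only for this example

loss_fn = nn.CrossEntropyLoss()
logits = torch.randn(n_batch, n_class)  # the input: float32, shape [nBatch, nClass]

# 1) Integer categorical class labels: shape [nBatch], dtype torch.int64 (long).
hard_target = torch.tensor([0, 2, 1, 2])  # Python ints give int64 by default
loss_hard = loss_fn(logits, hard_target)

# int32 labels are expected to be rejected; converting with .long() fixes that.
int32_target = hard_target.to(torch.int32)
try:
    loss_fn(logits, int32_target)
    int32_accepted = True
except (RuntimeError, TypeError):
    int32_accepted = False
loss_from_int32 = loss_fn(logits, int32_target.long())

# 2) "Soft" probabilistic labels: shape [nBatch, nClass], same float dtype as the input.
# One-hot rows are used here as the simplest valid probability distribution.
soft_target = F.one_hot(hard_target, num_classes=n_class).to(logits.dtype)
loss_soft = loss_fn(logits, soft_target)

print(hard_target.dtype, soft_target.dtype, int32_accepted)
print(loss_hard.item(), loss_soft.item())
```

With one-hot rows as the soft labels, the two losses should agree, since a one-hot distribution carries the same information as the integer class index.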
Best.
K. Frank