I’m trying to modify YOLO v1 to work with my task, in which each object has exactly one class (e.g., an object cannot be both a cat and a dog).
Because of the architecture (other outputs, such as the localization predictions, have to be trained with regression), a sigmoid is applied to the last output of the model (f.sigmoid(nearly_last_output)). For classification, YOLO v1 also uses MSE as the loss. But as far as I know, MSE sometimes doesn’t work as well as cross entropy for one-hot targets like mine.
Specifically, the ground truth (GT) looks like this: 0 0 0 0 1 (say we have only 5 classes in total; each object has exactly one class, so there is only a single 1)
and the model’s output for the classification part looks like this: 0.1 0.1 0.9 0.2 0.1
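For concreteness, here is what an MSE classification loss computes on the example above. This is a small sketch assuming the squared errors are averaged over the 5 classes, as `nn.MSELoss` does by default (YOLO v1 itself sums the squared errors rather than averaging them).

```python
import torch
import torch.nn as nn

# One-hot ground truth from the example: 5 classes, the object is class 4.
target = torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0])
# Sigmoid outputs of the classification part from the example.
output = torch.tensor([0.1, 0.1, 0.9, 0.2, 0.1])

# nn.MSELoss averages the squared errors over all 5 entries by default:
# (0.01 + 0.01 + 0.81 + 0.04 + 0.81) / 5
mse = nn.MSELoss()(output, target)
print(mse.item())  # approximately 0.336
```

The confidently wrong class 2 and the under-predicted true class 4 each contribute 0.81, and nothing in this loss forces the five outputs to behave like a probability distribution over mutually exclusive classes.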
I found suggestions to use nn.BCELoss / nn.BCEWithLogitsLoss, but I’m not good at math and may be wrong somewhere, so I’d like to ask here to learn more and to be sure what I should use.
I don’t fully understand your use case, but I would proceed as follows:
Use CrossEntropyLoss for the classification part of your model. To
do this, your model should output raw-score logits, not probabilities,
so the last layer of your model should most likely be a Linear layer
with nClass outputs. Don’t feed this through a sigmoid() (nor a softmax()). Your target (ground truth) should not be in one-hot
format, but rather integer class labels that run from 0 to nClass - 1.
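A minimal sketch of this setup, assuming one feature vector per object. `feature_dim`, `class_head`, and the batch shape are hypothetical; in YOLO v1 the class scores are predicted per grid cell, so you would reshape your output accordingly.

```python
import torch
import torch.nn as nn

n_class = 5
batch_size = 2
feature_dim = 16  # hypothetical size of the features feeding the classifier head

# Last layer of the classification head: a plain Linear, no sigmoid() or softmax().
class_head = nn.Linear(feature_dim, n_class)

features = torch.randn(batch_size, feature_dim)
logits = class_head(features)  # raw scores, shape (batch_size, n_class)

# One-hot ground truth in the format from the question ...
one_hot = torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0],
                        [0.0, 1.0, 0.0, 0.0, 0.0]])
# ... converted to integer class labels in [0, n_class - 1].
labels = one_hot.argmax(dim=1)  # tensor([4, 1])

# CrossEntropyLoss applies log_softmax internally, so it takes raw logits.
loss = nn.CrossEntropyLoss()(logits, labels)
print(labels, loss.item())
```

At inference time, `logits.argmax(dim=1)` gives the predicted class, and `torch.softmax(logits, dim=1)` gives class probabilities if you need them.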
(You could try to undo the sigmoid() that is being applied to your
“nearly-last” layer, but this risks becoming numerically unstable.)
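A small illustration of why undoing the sigmoid is risky, assuming float32 tensors (PyTorch’s default). Once a logit is large enough, its sigmoid rounds to exactly 1.0, and the inverse `torch.logit` can only return `inf`.

```python
import torch

# Raw logits; the last one is large but perfectly ordinary.
x = torch.tensor([0.0, 5.0, 20.0])
p = torch.sigmoid(x)  # in float32, sigmoid(20.0) rounds to exactly 1.0

# "Undo" the sigmoid with its inverse, log(p / (1 - p)).
recovered = torch.logit(p)
print(p)
print(recovered)  # close to 0 and 5 for the first two entries, inf for the last
```

`torch.logit` has an `eps` argument that clamps `p` away from 0 and 1, which avoids the `inf`, but then every large logit comes back as the same finite value, so the information is still lost.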
But if you need the sigmoid() for other parts of your loss function (You
talk about localization, regression, and MSE.), you should incorporate
the sigmoid() into that part of your loss function, rather than into your
model. (It is numerically better to apply the sigmoid() to your logits
in the MSE part of your loss than to try to undo a sigmoid() in the CrossEntropyLoss part of your loss.)
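A minimal sketch of that split, assuming a hypothetical model whose single raw output is sliced into 4 box values and `n_class` class logits. `detection_loss`, the tensor layout, and the equal weighting of the two terms are illustrative assumptions; the real YOLO v1 loss has further terms (confidence, coordinate weighting) that are omitted here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def detection_loss(box_logits, class_logits, box_targets, class_labels):
    """Hypothetical combined loss: MSE on sigmoid(box_logits), CE on raw class_logits.

    box_logits:   (N, 4) raw outputs for the regression part
    class_logits: (N, n_class) raw outputs for the classification part
    box_targets:  (N, 4) regression targets in [0, 1]
    class_labels: (N,) integer labels in [0, n_class - 1]
    """
    # The sigmoid lives in the loss (and in inference code), not in the model.
    box_loss = F.mse_loss(torch.sigmoid(box_logits), box_targets)
    # Cross entropy gets the raw logits; it applies log_softmax itself.
    class_loss = F.cross_entropy(class_logits, class_labels)
    return box_loss + class_loss  # equal weights: an assumption, tune as needed

# Hypothetical model: one Linear output split into box and class parts.
n_class, feature_dim = 5, 16
model = nn.Linear(feature_dim, 4 + n_class)

features = torch.randn(3, feature_dim)
out = model(features)  # no sigmoid applied inside the model
box_logits, class_logits = out[:, :4], out[:, 4:]

box_targets = torch.rand(3, 4)
class_labels = torch.tensor([4, 0, 2])
loss = detection_loss(box_logits, class_logits, box_targets, class_labels)
loss.backward()
```

At inference, apply `torch.sigmoid` to the box part yourself and take `argmax` over the class logits.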
(BCELoss is not appropriate for the classification part of your model
as you are working on a single-label, multi-class problem, not on
a so-called multi-label, multi-class problem.)
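A short comparison of the two settings, using made-up logits for 5 classes. The multi-hot target below is a hypothetical multi-label case (e.g., an object tagged both "cat" and "pet"), shown only to contrast with the single-label case in the question.

```python
import torch
import torch.nn as nn

logits = torch.tensor([[0.5, -1.0, 2.0, 0.0, 3.0]])  # raw scores for 5 classes

# Sigmoid scores each class independently: the values need not sum to 1.
print(torch.sigmoid(logits).sum(dim=1))
# Softmax makes the classes compete: the probabilities always sum to 1.
print(torch.softmax(logits, dim=1).sum(dim=1))

# Single-label, multi-class (your case): integer label + CrossEntropyLoss.
ce = nn.CrossEntropyLoss()(logits, torch.tensor([4]))

# Multi-label (hypothetical example): multi-hot float target + BCEWithLogitsLoss,
# i.e. one independent yes/no decision per class.
multi_hot = torch.tensor([[0.0, 0.0, 1.0, 0.0, 1.0]])
bce = nn.BCEWithLogitsLoss()(logits, multi_hot)
print(ce.item(), bce.item())
```

Because your classes are mutually exclusive, the softmax-based CrossEntropyLoss encodes that constraint directly, while BCE would let the model assign high scores to several classes at once.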