I can’t comment on the correctness of your custom focal loss implementation, as I usually use the multi-class implementation from e.g. kornia.
As described in the great post by @KFrank here (and also mentioned by me in an answer to another of your questions), you either use nn.BCEWithLogitsLoss for the binary classification or e.g. nn.CrossEntropyLoss if you are treating the use case as a 2-class multi-class classification.
torch.argmax(output, dim=1) would return the predictions for the latter use case, so you are currently mixing both approaches.
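For reference, here is a minimal sketch of the 2-class multi-class setup. The batch size, the random logits, and the random targets are made up for illustration and are not taken from your code; the only point is the expected shapes and where argmax fits in.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size = 4  # arbitrary value for this sketch

# Stand-in for the model output: raw logits with one column per class,
# i.e. shape [batch_size, 2]
output = torch.randn(batch_size, 2)
# Targets are class indices (dtype long) with shape [batch_size]
target = torch.randint(0, 2, (batch_size,))

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)

# argmax over the class dimension gives the predicted class index per sample
preds = torch.argmax(output, dim=1)
```

Here preds has the shape [batch_size] and contains class indices, so it can be compared directly to target.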
If your output has the shape [batch_size, 1] (as it should, given you are using F.binary_cross_entropy_with_logits), you would have to apply a threshold to the logits to get the predictions, e.g. via:
preds = output > 0.0
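A small sketch of the binary setup, assuming the model returns raw logits of shape [batch_size, 1]. The logit and target values below are invented for illustration. Thresholding the logits at 0.0 corresponds to thresholding the sigmoid probabilities at 0.5, which the last lines check for these values.

```python
import torch
import torch.nn.functional as F

# Made-up logits standing in for the model output, shape [batch_size, 1]
output = torch.tensor([[-2.0], [-0.5], [0.3], [1.7]])
# Binary targets as floats with the same shape as the output
target = torch.tensor([[0.0], [1.0], [1.0], [1.0]])

loss = F.binary_cross_entropy_with_logits(output, target)

# Threshold the logits directly; no argmax is involved in this setup
preds = output > 0.0

# Same decision expressed on probabilities: sigmoid(x) > 0.5 when x > 0
preds_from_probs = torch.sigmoid(output) > 0.5
assert torch.equal(preds, preds_from_probs)
```

preds is a bool tensor of shape [batch_size, 1]; call preds.float() if you need to compare it to a float target.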