Is this a correct implementation and use of focal loss for binary classification on vision transformer output? If it is correct, why are all train and val preds still stuck at zero?

I can’t comment on the correctness of your custom focal loss implementation, as I usually use the multi-class implementation from e.g. kornia.

As described in the great post by @KFrank here (and also mentioned by me in an answer to another of your questions), you would either use nn.BCEWithLogitsLoss for the binary classification, or e.g. nn.CrossEntropyLoss if you are treating the use case as a 2-class multi-class classification.
torch.argmax(output, dim=1) would return the predictions for the latter (2-class) use case, so you are currently mixing both approaches: applied to an output of shape [batch_size, 1], argmax over dim=1 will always return 0, which is why all your predictions are stuck at zero.
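For comparison, here is a minimal sketch of the 2-class approach where argmax is valid (shapes and the random inputs are just assumptions for illustration):

```python
import torch
import torch.nn as nn

batch_size = 4
output = torch.randn(batch_size, 2)          # logits with shape [batch_size, 2]
target = torch.randint(0, 2, (batch_size,))  # class indices in {0, 1}

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)

# argmax works here, since dim=1 contains one logit per class
preds = torch.argmax(output, dim=1)
```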
If your output has the shape [batch_size, 1] (as it should, given you are using F.binary_cross_entropy_with_logits), you would have to apply a threshold to the raw logits to get the predictions, e.g. via:

preds = output > 0.0
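Putting it together, a minimal sketch of the binary approach (shapes and the random inputs are assumptions for illustration) could look like this; note that thresholding the raw logits at 0.0 is equivalent to thresholding the sigmoid probabilities at 0.5:

```python
import torch
import torch.nn.functional as F

batch_size = 4
output = torch.randn(batch_size, 1)                    # raw logits, shape [batch_size, 1]
target = torch.randint(0, 2, (batch_size, 1)).float()  # float targets in {0., 1.}

loss = F.binary_cross_entropy_with_logits(output, target)

# threshold the logits at 0.0, i.e. sigmoid(output) > 0.5
preds = (output > 0.0).long()
```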