"predict_proba"

Hello,
I am working on a multi-class problem. When I worked with Keras, I could use the “predict_proba” function to see the probability of every class.
The model is Sequential() and I used a CNN.
Code: y_score = model.predict_proba(testX)

Is there a function in PyTorch like “predict_proba”?

Thank you for your help.

Assuming you are working on a multi-class classification use case, you can pass the input to the model directly and check the logits, calculate the probabilities, or the predictions:

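import torch
import torch.nn.functional as F

model.eval()  # disable dropout, use running batch-norm statistics
with torch.no_grad():  # gradients are not needed for inference
    logits = model(data)  # assuming logits has the shape [batch_size, nb_classes]
probs = F.softmax(logits, dim=1)  # per-class probabilities, each row sums to 1
preds = torch.argmax(logits, dim=1)  # predicted class index per sample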
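
If you want something closer to Keras' predict_proba, the steps above could be wrapped into a small helper. This is only a sketch: the helper name predict_proba is made up for illustration (it is not part of PyTorch), and the tiny linear model and random input are stand-ins for your CNN and test data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def predict_proba(model: nn.Module, data: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: class probabilities of shape [batch_size, nb_classes]."""
    was_training = model.training  # remember the current mode
    model.eval()  # inference behavior for dropout / batch norm
    with torch.no_grad():  # no gradients needed for inference
        logits = model(data)
    model.train(was_training)  # restore the previous mode
    return F.softmax(logits, dim=1)


torch.manual_seed(0)
model = nn.Linear(8, 3)  # stand-in for a CNN with 3 output classes (assumption)
data = torch.randn(4, 8)  # stand-in batch of 4 samples (assumption)

probs = predict_proba(model, data)  # each row sums to 1
preds = probs.argmax(dim=1)  # predicted class index per sample
```

The helper expects the model to return raw logits; if your model already ends with a softmax layer, you would skip the softmax here.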

Thank you so much. I will try it and report back. You are amazing, man.

And what would the code be if we have a single-class (binary) problem and are using torch.sigmoid()?

In that case you would apply a threshold to the probabilities to get the predictions.
E.g.:

probs = torch.sigmoid(output)
preds = probs > 0.5 # adapt threshold if needed
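
As a self-contained sketch of the binary case: the raw outputs below are made-up values standing in for model(data), and the model is assumed to return one logit per sample.

```python
import torch

# Made-up raw model outputs (logits), one per sample
output = torch.tensor([-2.0, -0.1, 0.0, 0.3, 4.0])

probs = torch.sigmoid(output)  # probability of the positive class, in [0, 1]
threshold = 0.5  # adapt threshold if needed
preds = (probs > threshold).long()  # 0/1 predictions

# sigmoid(0) == 0.5, so thresholding probs at 0.5 matches thresholding logits at 0
same = torch.equal(preds, (output > 0).long())
```

A different threshold trades false positives against false negatives, so it is usually picked on validation data rather than fixed at 0.5.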

Thanks for the reply.

I want to compute sklearn.metrics.roc_curve(), so I am using my code like this. Is it right?

output = torch.sigmoid( model(image) )

y_pred = output.cpu().numpy().flatten() >= 0.5

y_true = mask.flatten()

roc_curve(y_true , y_pred)

roc_curve expects the scores as an input, not the predictions, since the curve itself will be created using different thresholds.
From the docs:

y_score ndarray of shape (n_samples,)
Target scores, can either be probability estimates of the positive class, confidence values, or non-thresholded measure of decisions (as returned by “decision_function” on some classifiers).

According to scikit-learn:

Target scores, can either be probability estimates of the positive class, confidence values, or non-thresholded measure of decisions (as returned by “decision_function” on some classifiers).

Which PyTorch output can I use as the scores? I think torch.sigmoid returns probabilities?

Yes, you could pass the probabilities created by sigmoid into this method without applying the threshold.
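
A minimal sketch of that, assuming a segmentation-style setup like the one in the question: the random logits and mask below are stand-ins for model(image) and the ground-truth mask, so the resulting numbers are meaningless and only the data flow matters.

```python
import torch
from sklearn.metrics import roc_auc_score, roc_curve

torch.manual_seed(0)
logits = torch.randn(2, 1, 4, 4)  # stand-in for model(image) (assumption)
mask = (torch.rand(2, 1, 4, 4) > 0.5).long()  # stand-in binary ground truth (assumption)

# Scores are the sigmoid probabilities, NOT thresholded predictions
y_score = torch.sigmoid(logits).detach().cpu().numpy().flatten()
y_true = mask.cpu().numpy().flatten()

# roc_curve sweeps over thresholds itself and returns one point per threshold
fpr, tpr, thresholds = roc_curve(y_true, y_score)
auc = roc_auc_score(y_true, y_score)
```

Since sigmoid is monotonic, passing the raw logits instead of the probabilities should give the same fpr and tpr values; only the returned thresholds would differ.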