Hi Shaun!
Yes, this is correct. To give a little more detail, y_pred is a
number that indicates how strongly (in some units) the model
is predicting class = “1”. This number is typically called the
logit. probs = torch.sigmoid(y_pred) is the predicted
probability that class = “1”. And predicted_vals is the
predicted class label itself (0 or 1).
As a practical matter, you don’t need to calculate sigmoid.
You can save a little bit of time (though probably a trivial
amount) by leaving it out.
If threshold were 0.5 (that is, predict class = “1” when
P(class = “1”) > 1/2), then, because sigmoid (0) = 1/2, you
could use predicted_vals = y_pred > 0.
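Here is a minimal sketch of the logit, the probability, and the predicted label, together with the threshold-0.5 shortcut. The logit values are made up for illustration (in practice y_pred would be your model's output), and the names labels_from_probs and labels_from_logits are just mine.

```python
import torch

# Made-up logits for illustration; in practice these come from the model.
y_pred = torch.tensor([-2.0, -0.5, 0.3, 1.7])

probs = torch.sigmoid(y_pred)      # predicted P(class = "1")
labels_from_probs = probs > 0.5    # threshold applied to the probabilities
labels_from_logits = y_pred > 0    # same threshold applied to the logits

print(probs)
print(labels_from_logits)

# Both ways of thresholding agree for these values.
assert torch.equal(labels_from_probs, labels_from_logits)
```

For a logit extremely close to zero, floating-point rounding in sigmoid could in principle make the two comparisons disagree, but that should not matter in practice.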
More generally, you can compare y_pred with the
inverse-sigmoid of the threshold you want. This is typically
called the logit function, and is given by log (p / (1 - p)).
(Given a probability value p, 0 < p < 1, inverse-sigmoid (p) =
logit (p) = log (p / (1 - p)).)
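If you want to convince yourself that logit really is the inverse of sigmoid, a small check in plain Python (no pytorch needed) could look like this. The helper functions sigmoid() and logit() are written out here only for illustration.

```python
import math

def sigmoid(x: float) -> float:
    # Maps a logit to a probability in (0, 1).
    return 1.0 / (1.0 + math.exp(-x))

def logit(p: float) -> float:
    # Inverse of sigmoid; only defined for 0 < p < 1.
    if not 0.0 < p < 1.0:
        raise ValueError("p must be strictly between 0 and 1")
    return math.log(p / (1.0 - p))

# Round-tripping a probability through logit and sigmoid recovers it
# (up to floating-point round-off).
for p in (0.1, 0.5, 0.9):
    assert math.isclose(sigmoid(logit(p)), p)
```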
So:
logit_threshold = torch.tensor (threshold / (1 - threshold)).log()
...
predicted_vals = y_pred > logit_threshold
That is, instead of applying sigmoid to all of your y_preds,
you calculate inverse-sigmoid of threshold once.
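As a quick sanity check, here is a sketch that compares the two approaches for a threshold other than 0.5. The threshold of 0.7 and the logit values are assumptions chosen for illustration, as are the names via_sigmoid and via_logit.

```python
import torch

threshold = 0.7  # hypothetical threshold, chosen for illustration
y_pred = torch.tensor([-1.0, 0.5, 1.0, 3.0])  # made-up logits

# Inverse-sigmoid (logit) of the threshold, computed once.
logit_threshold = torch.tensor(threshold / (1 - threshold)).log()

via_sigmoid = torch.sigmoid(y_pred) > threshold  # sigmoid on every element
via_logit = y_pred > logit_threshold             # no sigmoid needed

print(logit_threshold)
print(via_logit)

# Both approaches give the same predicted labels for these values.
assert torch.equal(via_sigmoid, via_logit)
```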
Best.
K. Frank