Hello,
I am training a logistic regression model with PyTorch. I'm new to it and took the code from elsewhere; it uses 'accuracy' as a metric, defined as follows:
> def accuracy(model, x, y):
>     out = model(x)
>     correct = torch.abs(y - out) < 0.5
>     return correct.float().mean()
>
> plain_accuracy = accuracy(model, x_test, y_test)
> print(f"Accuracy on plain test_set: {plain_accuracy}")
Accuracy on plain test_set: 0.9837250709533691
The built-in torchmetrics Accuracy metric gives the same value, as expected:
from torchmetrics.classification import Accuracy
accuracy = Accuracy(task="binary", average='macro', num_classes=2)
accuracy(model(x_test), y_test)
tensor(0.9837)
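The two values agree because, for 0/1 labels, `torch.abs(y - out) < 0.5` amounts to thresholding the probabilities at 0.5 and comparing with the label. A minimal sketch of that equivalence, using small made-up tensors (`out` and `y` below are hypothetical stand-ins, not my actual data):

```python
import torch

# Hypothetical stand-in data shaped like model(x_test) and y_test
out = torch.tensor([[0.09], [0.62], [0.48], [0.91]])
y = torch.tensor([[0.], [1.], [1.], [0.]])

# Accuracy as defined in the function above
manual = (torch.abs(y - out) < 0.5).float().mean()

# Accuracy from thresholding the probabilities at 0.5
thresholded = ((out > 0.5).float() == y).float().mean()

print(manual, thresholded)
```

The two expressions can only differ when an output is exactly 0.5.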
However, I want to use Precision and PrecisionRecallCurve as metrics.
My inputs are like this:
model(x_test)
tensor([[0.0898],
[0.0952],
[0.0879],
...,
[0.0911],
[0.0853],
[0.0902]], grad_fn=<SigmoidBackward0>)
y_test
tensor([[0.],
[0.],
[0.],
...,
[0.],
[0.],
[0.]])
type(y_test)
torch.Tensor
My Precision metric returns zero, tensor(0.), and the following code for PrecisionRecallCurve raises an error:
from torchmetrics.classification import PrecisionRecallCurve
pr_curve = PrecisionRecallCurve(task="binary")
precision, recall, thresholds = pr_curve(model(x_test), y_test)
precision
ValueError Traceback (most recent call last)
Input In [24], in <cell line: 3>()
1 from torchmetrics.classification import PrecisionRecallCurve
2 pr_curve = PrecisionRecallCurve(task="binary")
----> 3 precision, recall, thresholds = pr_curve(model(x_test), y_test)
4 precision
.............
File ~\anaconda3\lib\site-packages\torchmetrics\functional\classification\precision_recall_curve.py:136, in _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index)
133 _check_same_shape(preds, target)
135 if target.is_floating_point():
--> 136 raise ValueError(
137 "Expected argument `target` to be an int or long tensor with ground truth labels"
138 f" but got tensor with dtype {target.dtype}"
139 )
141 if not preds.is_floating_point():
142 raise ValueError(
143 "Expected argument `preds` to be an floating tensor with probability/logit scores,"
144 f" but got tensor with dtype {preds.dtype}"
145 )
ValueError: Expected argument `target` to be an int or long tensor with ground truth labels but got tensor with dtype torch.float32
Can someone explain what is causing this ValueError?