I came up with the answer: torch.sigmoid(classified_labels).data > 0.5 will give the correct labels with MultiLabelSoftMarginLoss().
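To make that concrete, here is a minimal sketch, assuming classified_labels holds the raw, pre-sigmoid outputs of the model (which is what MultiLabelSoftMarginLoss expects as input). The tensor values and the names targets, criterion, probs and predicted are made up for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical raw model outputs (logits): 4 samples, 3 independent labels.
classified_labels = torch.tensor([[ 2.0, -1.0,  0.5],
                                  [-3.0,  0.1, -0.2],
                                  [ 0.0,  4.0, -4.0],
                                  [ 1.5,  1.5, -1.5]])

# Hypothetical multi-hot ground truth with the same shape as the logits.
targets = torch.tensor([[1., 0., 1.],
                        [0., 1., 0.],
                        [0., 1., 0.],
                        [1., 1., 0.]])

# Training side: the loss takes the raw logits and applies the sigmoid internally.
criterion = nn.MultiLabelSoftMarginLoss()
loss = criterion(classified_labels, targets)

# Prediction side: apply the sigmoid yourself, then threshold each label at 0.5.
with torch.no_grad():
    probs = torch.sigmoid(classified_labels)
    predicted = probs > 0.5  # boolean tensor, one entry per (sample, label)
```

Since sigmoid(x) > 0.5 exactly when x > 0, thresholding the raw outputs at zero gives the same labels. The 0.5 cutoff is only a default and can be tuned per label if precision and recall matter differently.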