Proper way of doing binary classification with one probability output? (What loss function/activation function to use, and how to compute accuracy?)

Hello everyone, sorry for the rookie question; I’m just starting to learn PyTorch.

This is my simple one-layer linear classifier:

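import torch
import torch.nn as nn

class Classifier(nn.Module):
    def __init__(self, in_dim):
        super(Classifier, self).__init__()
        # Single linear layer mapping in_dim features to one output
        self.classify = nn.Linear(in_dim, 1)

    def forward(self, features):
        # Sigmoid squashes the output into (0, 1) so it reads as a probability
        final = torch.sigmoid(self.classify(features))
        return final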

I want the output to be a probability, so that ~1 means class 1 and ~0 means class 0.

But I don’t know which loss function to use, or how to calculate the accuracy in each epoch when I’m using batching.

This is my current training loop, but the loss is not correct. I feel like I need to change the code, because it was written for multi-class classification, not for single-output classification:

  

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

net.train()


running_loss = 0
total_iters = len(trainloader)

for pos, (train_samples, labels) in zip(bar, trainloader):

    outputs = net(train_samples)

    loss = criterion(outputs, labels.float() )
    running_loss += loss.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


running_loss = running_loss / total_iters

return running_loss

For the loss function, switch out CrossEntropyLoss for BCELoss. For the accuracy, I usually like to write a separate function that computes it over the whole set, and call that from within my training loop. The function takes the model and the train dataloader as arguments.
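One way such an accuracy helper might look is sketched below. It assumes the model returns probabilities of shape (batch, 1), as in the original Classifier with sigmoid; the name compute_accuracy, the threshold argument, and the toy data are illustrative and not from the thread.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def compute_accuracy(model, loader, threshold=0.5):
    """Fraction of correct hard predictions over every batch in `loader`."""
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    for samples, labels in loader:
        probs = model(samples).view(-1)          # (batch, 1) -> (batch,)
        preds = (probs > threshold).long()       # hard 0/1 predictions
        correct += (preds == labels.view(-1).long()).sum().item()
        total += labels.numel()
    model.train(was_training)                    # restore the previous mode
    return correct / total


# Hypothetical toy setup: a fixed model computing sigmoid(x) on one feature.
model = nn.Sequential(nn.Linear(1, 1), nn.Sigmoid())
with torch.no_grad():
    model[0].weight.fill_(1.0)
    model[0].bias.fill_(0.0)

x = torch.tensor([[-2.0], [-1.0], [1.0], [2.0]])
y = torch.tensor([0, 0, 1, 0])                   # last label disagrees with the model
loader = DataLoader(TensorDataset(x, y), batch_size=3)  # uneven batches: 3 + 1

acc = compute_accuracy(model, loader)
print(acc)
```

Counting correct predictions and samples across batches, rather than averaging per-batch accuracies, keeps the result right when the last batch is smaller than the others. If the model returns raw logits instead of probabilities, passing threshold=0.0 gives the same predictions.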

Hi Richard!

As Prashanth notes, you could use BCELoss in place of
CrossEntropyLoss.

However, you’ll be better off removing the torch.sigmoid()
and using BCEWithLogitsLoss. Doing so will be mathematically
the same, but numerically more stable.

Thus:

class Classifier(nn.Module):
    def __init__(self, in_dim):
        super(Classifier, self).__init__()
        self.classify = nn.Linear(in_dim, 1)

    def forward(self, features):
        # No sigmoid here: return the raw logit
        final = self.classify(features)
        return final

and:

criterion = nn.BCEWithLogitsLoss()
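As a rough check of the "mathematically the same, but numerically more stable" point, the sketch below compares the two losses on a few made-up logits and then on one extreme logit. The tensors are illustrative values only. Note that both losses expect float targets with the same shape as the model output.

```python
import torch
import torch.nn as nn

# Made-up logits and float targets, both of shape (batch, 1).
logits = torch.tensor([[-3.0], [0.5], [4.0]])
targets = torch.tensor([[0.0], [1.0], [1.0]])

loss_from_probs = nn.BCELoss()(torch.sigmoid(logits), targets)
loss_from_logits = nn.BCEWithLogitsLoss()(logits, targets)
same = torch.allclose(loss_from_probs, loss_from_logits, atol=1e-6)

# An extreme, confidently wrong logit: sigmoid saturates to exactly 1.0 in
# float32, so BCELoss has to take log(0). PyTorch documents that BCELoss
# clamps its log outputs at -100, which caps the reported loss.
big_logit = torch.tensor([[200.0]])
zero_target = torch.tensor([[0.0]])
capped = nn.BCELoss()(torch.sigmoid(big_logit), zero_target)
exact = nn.BCEWithLogitsLoss()(big_logit, zero_target)

print(same, capped.item(), exact.item())
```

For moderate logits the two agree; for the saturated case only the logits-based loss still reflects how wrong the prediction is.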

Your Classifier will now output raw-score logits that range from
-inf to inf instead of probabilities. Should you need probabilities
for subsequent processing, you can always pass the logits through
sigmoid(). Note, you don’t need probabilities to make hard 0-1
predictions: prediction = 1 if logit > 0.0 is the same as
prediction = 1 if probability > 0.5.
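The equivalence of the two thresholds can be checked on a few made-up logits (the values below are illustrative only):

```python
import torch

logits = torch.tensor([-2.0, -0.1, 0.3, 5.0])
probs = torch.sigmoid(logits)                 # only needed if probabilities are wanted

preds_from_logits = (logits > 0.0).long()     # threshold the raw scores at 0
preds_from_probs = (probs > 0.5).long()       # threshold the probabilities at 0.5

print(preds_from_logits.tolist(), preds_from_probs.tolist())
```

This works because sigmoid is strictly increasing and sigmoid(0.0) is 0.5, so both comparisons split the inputs at the same point.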

Two side comments:

As written, you never call scheduler.step() so scheduler doesn’t
do anything.
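A minimal sketch of a loop that does call scheduler.step() is given below. It assumes that step_size=50 is meant in epochs, so the scheduler is stepped once per epoch after the optimizer updates. The toy data, the nn.Linear stand-in for the Classifier, and num_epochs are assumptions made only so the sketch runs on its own.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: 64 samples, 4 features, label 1 when the feature sum is positive.
features = torch.randn(64, 4)
labels = (features.sum(dim=1) > 0).long()
trainloader = DataLoader(TensorDataset(features, labels), batch_size=16, shuffle=True)

net = nn.Linear(4, 1)                  # stands in for a Classifier that returns raw logits
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

num_epochs = 100
for epoch in range(num_epochs):
    net.train()
    running_loss = 0.0
    for train_samples, batch_labels in trainloader:
        outputs = net(train_samples).squeeze(1)          # (batch, 1) -> (batch,)
        loss = criterion(outputs, batch_labels.float())  # float targets, same shape as outputs
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    running_loss /= len(trainloader)   # mean batch loss for this epoch
    scheduler.step()                   # once per epoch, after the optimizer steps

final_lr = scheduler.get_last_lr()[0]
print(running_loss, final_lr)
```

With these settings the learning rate is halved after every 50 epochs, so it has been halved twice by the end of the 100 epochs.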

For getting started with the code, one Linear layer is fine, but it
won’t be much of a classifier for anything but special toy problems.
Leaving aside the sigmoid(), your single output is just a linear
function of your in_dim inputs. Things already get much more
interesting (and useful) if you add a single “hidden” layer:

class Classifier (nn.Module):
    def __init__ (self, in_dim, hidden_dim):
        super (Classifier, self).__init__()
        self.fc1 = nn.Linear (in_dim, hidden_dim)
        self.activation = nn.ReLU()   # for example
        self.fc2 = nn.Linear (hidden_dim, 1)
    def forward (self, features):
        x = self.fc1 (features)
        x = self.activation (x)
        x = self.fc2 (x)
        return x

For more interesting classification tasks, the non-linear activation
(for example, ReLU) between fc1 and fc2 is the “secret sauce.”
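To see why the activation matters, the sketch below checks that two Linear layers stacked without an activation collapse into a single linear (affine) map, so they add nothing over one layer. The dimensions and random inputs are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_dim, hidden_dim = 3, 5
fc1 = nn.Linear(in_dim, hidden_dim)
fc2 = nn.Linear(hidden_dim, 1)
x = torch.randn(8, in_dim)

with torch.no_grad():
    stacked = fc2(fc1(x))                      # two Linear layers, no activation

    # The same function written as one affine map: W x + b
    W = fc2.weight @ fc1.weight                # shape (1, in_dim)
    b = fc2.weight @ fc1.bias + fc2.bias       # shape (1,)
    collapsed = x @ W.T + b

    with_relu = fc2(torch.relu(fc1(x)))        # ReLU in between breaks the collapse

collapses = torch.allclose(stacked, collapsed, atol=1e-5)
relu_differs = not torch.allclose(stacked, with_relu, atol=1e-5)
print(collapses, relu_differs)
```

Only with the non-linearity in between can the network represent decision boundaries that a single linear layer cannot.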

Best.

K. Frank
