Loss function for binary classification with Pytorch

Hi everyone,

I am trying to implement a model for a binary classification problem. Up to now, I was using the softmax function (at the output layer) together with torch.nn.NLLLoss to calculate the loss. However, now I want to use the sigmoid function (instead of softmax) at the output layer. If I do that, should I also change the loss function, or can I still use torch.nn.NLLLoss?

Suppose the following simple model:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 2)  # two outputs, one score per class

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x
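For reference, the softmax setup described in the question might look like this with the two-output model. nn.NLLLoss expects log-probabilities, so this sketch applies F.log_softmax rather than a plain softmax. The nn.Sequential stand-in for Net, the batch size, and the random data are assumptions for illustration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for the two-output Net above: 10 input features -> 2 class scores.
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 2))

x = torch.randn(8, 10)               # hypothetical batch of 8 samples
target = torch.randint(0, 2, (8,))   # class indices 0 or 1 (int64, as NLLLoss needs)

logits = model(x)                          # shape (8, 2), unnormalized scores
log_probs = F.log_softmax(logits, dim=1)   # NLLLoss expects log-probabilities
loss = nn.NLLLoss()(log_probs, target)

# log_softmax followed by NLLLoss is what F.cross_entropy computes on raw logits.
ce = F.cross_entropy(logits, target)
print(loss.item(), ce.item())
```

Feeding the output of a plain softmax (probabilities, not log-probabilities) into NLLLoss would silently compute a different quantity, which is why the log is taken first.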

You can add an nn.Sigmoid() layer to get sigmoid(x):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.sigmoid(x)  # squash each output into (0, 1)
        return x
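Switching the output layer from softmax to sigmoid changes what the two outputs mean: each unit becomes an independent probability in (0, 1), so a row is no longer a distribution over the two classes. That is one reason a binary cross-entropy loss, rather than NLLLoss, is the natural pairing. A small sketch with a hypothetical random batch, using nn.Sequential as a stand-in for Net:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same layer sizes as the Net above, with a sigmoid on the two outputs.
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 2), nn.Sigmoid())

x = torch.randn(4, 10)   # hypothetical batch
probs = model(x)         # shape (4, 2); each entry is an independent probability

# Every entry lies in (0, 1) ...
print(probs)
# ... but, unlike softmax output, the two entries in a row need not sum to 1.
print(probs.sum(dim=1))

# For comparison: softmax over the same pre-sigmoid scores gives rows summing to 1.
softmax_sums = torch.softmax(model[:2](x), dim=1).sum(dim=1)
print(softmax_sums)
```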

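If you use binary cross-entropy loss, you can compute the loss as:

model = Net()
x = torch.randn(4, 10)                    # example batch: 4 samples, 10 features
t = torch.randint(0, 2, (4, 2)).float()   # example 0/1 targets, same shape as y
y = model(x)                              # sigmoid outputs in (0, 1)
loss = (-t * torch.log(y) - (1 - t) * torch.log(1 - y)).mean()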
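One practical caveat with writing the formula by hand: if the sigmoid saturates to exactly 0 or 1, one of the logs becomes -inf and the loss blows up. The nn.BCELoss documentation states that BCELoss clamps its log outputs to be greater than or equal to -100; the sketch below reproduces that clamp. The function names bce_naive and bce_clamped are hypothetical.

```python
import torch
import torch.nn.functional as F

def bce_naive(y, t):
    # Direct transcription of -t*log(y) - (1-t)*log(1-y), averaged over elements.
    return (-t * torch.log(y) - (1 - t) * torch.log(1 - y)).mean()

def bce_clamped(y, t):
    # Same formula, but each log is clamped at -100, mirroring the clamping
    # that the nn.BCELoss documentation describes.
    log_y = torch.log(y).clamp(min=-100)
    log_1my = torch.log(1 - y).clamp(min=-100)
    return (-t * log_y - (1 - t) * log_1my).mean()

# A saturated prediction: y == 1.0 for a sample whose target is 0.
y = torch.tensor([0.9, 1.0])
t = torch.tensor([1.0, 0.0])

print(bce_naive(y, t))               # inf, because log(1 - 1.0) = -inf
print(bce_clamped(y, t))             # finite
print(F.binary_cross_entropy(y, t))  # agrees with bce_clamped
```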

P.S. Changed nn.LogSigmoid to nn.Sigmoid.

For the sake of completeness: you can also use nn.Sigmoid as the output layer and nn.BCELoss in case you don’t want to write the formula yourself.
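A minimal training-step sketch of that combination. One assumption differs from the thread: the model here has a single output unit (the probability of class 1), which is a common choice with a sigmoid; nn.BCELoss works equally with the two-output Net as long as the targets have the same shape as the output. The batch, learning rate, and optimizer are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 1)  # assumption: one output unit = P(class 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.linear2(self.linear1(x)))

model = Net()
criterion = nn.BCELoss()  # expects probabilities in [0, 1], not log-probabilities
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(16, 10)                   # hypothetical batch
t = torch.randint(0, 2, (16, 1)).float()  # float 0/1 targets, same shape as the output

y = model(x)
loss = criterion(y, t)
# The hand-written formula gives the same value (mean over the batch).
manual = (-t * torch.log(y) - (1 - t) * torch.log(1 - y)).mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item(), manual.item())
```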


The above comment confused me a little. If I want to use nn.BCELoss, should I take the log of the nn.Sigmoid output inside the forward function?

Sorry for the confusion. No, if you are using nn.BCELoss, you should just apply a sigmoid to your output.
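To make that concrete: nn.BCELoss takes the sigmoid probabilities themselves, with no extra log. As an aside not raised in the thread, PyTorch also provides nn.BCEWithLogitsLoss, which takes the raw outputs and applies the sigmoid internally in a more numerically stable way. The sketch checks that the two agree on hypothetical random logits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

z = torch.randn(6, 1)                    # hypothetical raw model outputs (logits)
t = torch.randint(0, 2, (6, 1)).float()  # hypothetical 0/1 targets

# As recommended above: sigmoid first, then plain probabilities into BCELoss.
loss_bce = nn.BCELoss()(torch.sigmoid(z), t)

# Alternative: leave the sigmoid out of the model and let BCEWithLogitsLoss apply it.
loss_logits = nn.BCEWithLogitsLoss()(z, t)

print(loss_bce.item(), loss_logits.item())
```

If you go the BCEWithLogitsLoss route, the nn.Sigmoid layer would be removed from forward, and torch.sigmoid would be applied only when you need probabilities at inference time.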

Also, I’m not sure @kenmikanmi’s approach will work, as the second term seems to have a small mistake.
The second term should be (1 - t) * log(1 - sigmoid(x)), whereas the formula currently uses (1 - t) * (1 - logsigmoid(x)).
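A quick numerical check of that correction on three made-up values. The flawed second term can even make the loss negative, which a cross-entropy never is. The last variant uses the identity 1 - sigmoid(x) = sigmoid(-x), so log(1 - sigmoid(x)) can be computed as logsigmoid(-x).

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, 0.0, 3.0])  # hypothetical raw outputs
t = torch.tensor([0.0, 1.0, 0.0])   # hypothetical targets

# Correct binary cross-entropy: -[t*log(sigmoid(x)) + (1-t)*log(1 - sigmoid(x))]
correct = -(t * F.logsigmoid(x) + (1 - t) * torch.log(1 - torch.sigmoid(x)))

# Flawed version: second term uses (1 - logsigmoid(x)) instead of log(1 - sigmoid(x)).
flawed = -(t * F.logsigmoid(x) + (1 - t) * (1 - F.logsigmoid(x)))

# Same as `correct`, rewritten with log(1 - sigmoid(x)) == logsigmoid(-x).
stable = -(t * F.logsigmoid(x) + (1 - t) * F.logsigmoid(-x))

print(correct)
print(flawed)  # first and last entries are negative
print(stable)
```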

@coyote Sorry for the mistake; I have modified my first comment (please use nn.Sigmoid() instead).

As @ptrblck first mentioned, I also think nn.BCELoss is the better option when you don't need a custom loss function.
