Sigmoid Giving Only Negative Outputs

Hello All,
I am building an LSTM-based classifier for EEG motor-imagery data with 2 classes. The data comes from a 64-channel EEG, and each channel has 20000 data points. The model is trained for 50 epochs and converges to a decent loss. However, when testing the model on an individual file, the sigmoid outputs only negative values irrespective of the class. Please tell me where I could be going wrong. Sorry for the trivial question; I am relatively new to the field and recently shifted to PyTorch.
I am attaching my code below.


import time

import matplotlib.pyplot as plt
import torch
import torch.nn as net


class pytorch_lstm(net.Module):
    def __init__(self, features, hidden_size, sequence_length, loss_function='BCE'):
        super(pytorch_lstm, self).__init__()
        self.loss = loss_function
        self.features = features
        self.hidden_size = hidden_size
        self.seq_length = sequence_length
        self.lstm = net.LSTM(
            input_size=self.features,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True
        )
        self.linear = net.Linear(self.hidden_size * self.seq_length, 1)  # maps the flattened LSTM outputs to a single raw score

    def init_Hidden(self):
        hidden_state = torch.zeros(1, 1, self.hidden_size).cuda()
        cell_state = torch.zeros(1, 1, self.hidden_size).cuda()
        self.hidden = (hidden_state, cell_state)

    def forward(self, X):
        lstm_out, self.hidden = self.lstm(X, self.hidden)
        out = lstm_out.contiguous().view(1, -1)
        out = self.linear(out)
        return out

    def train_model(self, model, dataloader, num_epochs):
        least_loss = 10
        model.train()
        if self.loss == 'BCE':
            self.criterion = net.BCEWithLogitsLoss().cuda()
        elif self.loss == 'MSE':
            self.criterion = net.MSELoss().cuda()
        else:
            raise ValueError('PROVIDE A VALID LOSS FUNCTION')

        optimizer = torch.optim.Adam(model.parameters())
        training_loss = []
        for i in range(num_epochs):
            st = time.time()
            epoch_loss = 0
            for _, (x, y) in enumerate(dataloader):
                model.init_Hidden()
                optimizer.zero_grad()  # clear the gradients for every batch, not once per epoch
                output = model(x)
                loss = self.criterion(output.view(-1), y)
                loss.backward()
                optimizer.step()
                print(loss.item())
                epoch_loss += loss.item()
            et = time.time()
            avg_loss = epoch_loss / len(dataloader)  # average over the number of batches
            print('----------------------------------------------------------')
            print('\033[1m' + 'TOTAL_TIME_PER_EPOCH = ' + str(et - st) + '\033[0m')
            print('\033[1m' + 'AVERAGE_LOSS = ' + str(avg_loss) + '\033[0m')
            training_loss.append(avg_loss)
            if avg_loss < least_loss:
                torch.save(model.state_dict(), 'path' + self.loss + '.pth')
                least_loss = avg_loss

        plt.plot(training_loss)
        plt.xlabel('EPOCHS')
        plt.ylabel('LOSS')
        plt.show()

The model outputs something like -

tensor([[-0.0545]], device='cuda:0') CLASS 1
tensor([[-0.0541]], device='cuda:0') CLASS 2
tensor([[-0.0522]], device='cuda:0') CLASS 1

Interestingly, the sigmoid outputs are quite close to each other. Can someone please tell me why?
Thanks in Advance

Hi Atharva!

First, I don’t see a sigmoid() anywhere in the code you posted.
(I suppose that there could be one hiding in the superclass.)

Next, your model outputs the result of a Linear layer, which, in
general, can produce negative values.

Last, your use of BCEWithLogitsLoss suggests that your model is not
meant to produce positive probabilities via sigmoid(), but rather
unbounded (including negative) logits (“raw scores”) via the
Linear layer.
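
If what you actually want at test time is a probability, you can
apply sigmoid() yourself to the raw output. A minimal sketch
(assuming your trained model and a single input tensor x of shape
(1, seq_length, features)):

import torch

model.eval()
model.init_Hidden()
with torch.no_grad():
    logit = model(x)              # raw score from Linear; may well be negative
    prob = torch.sigmoid(logit)   # probability of class 1, always in (0, 1)
print(logit.item(), prob.item())  # e.g. -0.0545 -> 0.4864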

If this doesn’t clear things up, please explain in more detail why you
are expecting your outputs to be positive.

Good luck!

K. Frank

Thanks a lot @KFrank.
I had read that BCEWithLogitsLoss() essentially has a sigmoid built into it, so during training it was fine, but when predicting for an individual input I will have to apply a sigmoid to the output. Is my understanding correct?

Hello Atharva!

It depends on what you mean by “predict.”

sigmoid() converts logits to probabilities. sigmoid(0) = 1/2.

If you wish to predict the value of a binary variable, that is, say
whether your best guess for the value of the variable is 0 or 1,
then it is sensible to say that if the probability of the variable
being 1 is greater than 1/2, you will predict 1; otherwise you
will predict 0.

So, prediction = (probability >= 1/2). But this is the
same as prediction = (logit >= 0). So you don’t need
the sigmoid() or the probability to make this kind of prediction.
Simply take the output of your model (which is the output of a
Linear, and should be understood as a logit) and compare it to 0.
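
For example (with a few hypothetical logit values, not your actual
outputs), the two thresholds give identical predictions:

import torch

logits = torch.tensor([-0.0545, 0.3, -1.2, 2.0])
pred_from_prob  = (torch.sigmoid(logits) >= 0.5).long()  # threshold probabilities at 1/2
pred_from_logit = (logits >= 0.0).long()                 # threshold logits at 0
print(pred_from_prob)   # tensor([0, 1, 0, 1])
print(pred_from_logit)  # tensor([0, 1, 0, 1]) -- identical, so sigmoid() isn't needed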

But if by “predict” you mean that you want to state the probability
of your variable being 1, then you do, of course, need to convert
the logit to a probability by passing it through a sigmoid().

When you compute the accuracy of your model, you usually
compare 0-or-1-valued predictions with 0-or-1-valued targets,
so you don’t need sigmoid() – you just compare your logits
with 0 to get the 0-or-1 (False-or-True) prediction.

When you compute the loss for your model, you often use
cross-entropy, which works with probabilities, so the computation
needs to convert the logits to probabilities. You can do this
yourself with an explicit sigmoid(), but for numerical reasons,
it is better to use BCEWithLogitsLoss which has (in effect) a
sigmoid() built into it.
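
As a sketch (with a hypothetical batch of logits and 0-or-1 targets,
not your data), the two computations look like this:

import torch
import torch.nn as net

logits  = torch.tensor([-0.05, 1.3, -2.1, 0.7])
targets = torch.tensor([0., 1., 0., 0.])

# accuracy: hard predictions from the logits, compared directly with the targets
preds = (logits >= 0.0).float()
accuracy = (preds == targets).float().mean()

# loss: BCEWithLogitsLoss takes the raw logits; sigmoid() is applied internally
criterion = net.BCEWithLogitsLoss()
loss = criterion(logits, targets)

print(accuracy.item(), loss.item())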

Best.

K. Frank