LSTM classifier always predicts same probability for binary text classification

I’m trying to implement an LSTM network to classify text as spam or non-spam. The model does not seem to train: the loss does not change over epochs, so it always predicts the same values. Most recently it predicted [0.4950] for all test samples, so it always predicts class 0. I train for 50 epochs with the Adam optimizer (I also tried SGD) at a learning rate of 0.0001 (I tried 0.001 as well, with the same results). I’m really confused about the cause of this issue. What is the problem?
My classifier:

import torch
import torch.nn as nn
from torch.autograd import Variable

class LSTM_clf(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, num_classes, batch_size, max_seq_len):
        super(LSTM_clf, self).__init__()
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len
        self.gpu = torch.cuda.is_available()
        self.embedding = nn.Embedding(vocab_size, embed_size)  # look-up table for the word embeddings
        
        # self.embedding.load_state_dict({'weight': torch.Tensor(emb_weights)})  # optionally assign pre-trained GloVe weights

        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, dropout=0.3, bidirectional=True)
        self.lstm2out = nn.Linear(2*hidden_size, vocab_size)  # unused
        self.hidden2out = nn.Linear(2*hidden_size, 1)
        self.hidden = self.init_hidden()
        self.fc = nn.Sequential(  # unused
            nn.Linear(2*hidden_size, 100),
            nn.Linear(100, 1),
        )
   
    def init_hidden(self):
        h = Variable(torch.zeros((2*self.num_layers, self.batch_size, self.hidden_size)))
        c = Variable(torch.zeros((2*self.num_layers, self.batch_size, self.hidden_size)))
        return h, c
    
    def forward(self, x):
        # x: (seq_len, batch_size) after the transpose in the training code
        x = self.embedding(x)
        self.hidden = self.init_hidden()
        x, self.hidden = self.lstm(x, self.hidden)  # x: (seq_len, batch_size, 2*hidden_size)
        x = self.hidden2out(x[-1])  # take the last time step: (batch_size, 1)
        out = torch.sigmoid(x)
        return out.view(-1)

Training code:

clf = LSTM_clf(embed_size=EMBEDDING_DIM,
               hidden_size=LSTM_HIDDEN_DIM,
               vocab_size=VOCAB_SIZE,
               num_layers=2,
               num_classes=2,
               batch_size=BATCH_SIZE,
               max_seq_len=MAX_SEQ_LEN)
clf_optimizer = optim.Adam(clf.parameters(), lr=CLF_LR)
clf_criterion = nn.BCELoss()
best_loss = float('inf')

for epoch in range(1, EPOCHS + 1):
    print("EPOCH:", epoch)
    loss, acc, rcl, prc, f1, auc, fpr0, tpr0, fpr1, tpr1 = train_classifier(clf, train_dl, clf_criterion, clf_optimizer)
    print(f"LOSS: {loss:.7f}; ACCURACY: {acc*100:.7f}; RECALL: {rcl:.7f}; PRECISION: {prc:.7f}; F1-SCORE: {f1}; AUC: {auc};")
    print(f"FALSE POSITIVE RATE0: {fpr0}; TRUE POSITIVE RATE0: {tpr0}; FALSE POSITIVE RATE1: {fpr1}; TRUE POSITIVE RATE1: {tpr1};")
def clf_training(model, train_inputs, train_labels, optimizer, criterion, filename=''):  # one batch
    train_inputs, train_labels = Variable(train_inputs), Variable(train_labels).float()
    optimizer.zero_grad()
    output = model(train_inputs.long().t())
    loss = criterion(output, train_labels)
    loss.backward(retain_graph=True)
    optimizer.step()

    loss = loss.data.item()

    acc = accuracy_score((output > 0.5).float().cpu(), train_labels.cpu())
    rcl = recall_score((output > 0.5).float().cpu(), train_labels.cpu()).item()
    prc = precision_score((output > 0.5).float().cpu(), train_labels.cpu()).item()
    f1 = f1_score(train_labels.cpu(), (output > 0.5).float().cpu()).item()

    fpr, tpr, thresholds = roc_curve(train_labels.cpu(), output.float().cpu().data.numpy())
    auc_s = metrics.auc(fpr, tpr)

    return acc, rcl, prc, f1, auc_s, fpr, tpr, loss, (output > 0.5).float()

def train_classifier(model, train_dl, criterion, optimizer, filename='', i=0):  # one epoch
    model.train()
    total_acc = 0
    total_rcl = 0
    total_prc = 0
    total_f1 = 0
    total_auc = 0
    total_loss_clf = 0
    total_fpr0 = 0
    total_tpr0 = 0
    total_fpr1 = 0
    total_tpr1 = 0
    total = len(train_dl)  # number of batches

    for i, (train_inputs, train_labels) in tqdm_notebook(enumerate(train_dl), desc='Training', total=len(train_dl)):
        acc, rcl, prc, f1, auc_s, fpr, tpr, loss_clf, _ = clf_training(model, train_inputs, train_labels, optimizer, criterion, filename)
        total_acc += acc
        total_rcl += rcl
        total_prc += prc
        total_f1 += f1
        total_auc += auc_s
        total_loss_clf += loss_clf
        total_fpr0 += fpr[0]
        total_tpr0 += tpr[0]
        total_fpr1 += fpr[1]
        total_tpr1 += tpr[1]

    return total_loss_clf/total, total_acc/total, total_rcl/total, total_prc/total, total_f1/total, total_auc/total, total_fpr0, total_tpr0, total_fpr1, total_tpr1

Evaluation code:

def evaluate_classifier(model, test_inputs, test_labels, criterion):
    test_inputs, test_labels = Variable(test_inputs), Variable(test_labels).float()
    output = model(test_inputs.long().t())
    loss = criterion(output, test_labels).data.item()
    print(loss)
    print(output)
    print(test_labels)
    acc = accuracy_score((output > 0.5).float().cpu(), test_labels.cpu()).item()
    print(acc)
    rcl = recall_score((output > 0.5).float().cpu(), test_labels.cpu()).item()
    prc = precision_score((output > 0.5).float().cpu(), test_labels.cpu()).item()
    f1 = f1_score((output > 0.5).float().cpu(), test_labels.cpu()).item()

    fpr, tpr, thresholds = roc_curve(test_labels.cpu(), output.float().cpu().data.numpy())
    auc_s = metrics.auc(fpr, tpr)
    return acc, rcl, prc, f1, auc_s, fpr, tpr, loss, (output > 0.5).float()


def test_classifier(model, test_dl, criterion, i=0):
    total_acc = 0
    total_rcl = 0
    total_prc = 0
    total_f1 = 0
    total_auc = 0
    total_loss_clf = 0
    total_fpr0 = 0
    total_tpr0 = 0
    total_fpr1 = 0
    total_tpr1 = 0
    print(len(test_dl))
    total = len(test_dl)  # number of batches
    model.eval()
    with torch.no_grad():
        for i, (test_inputs, test_labels) in tqdm_notebook(enumerate(test_dl), desc='TEST', total=len(test_dl)):
            acc, rcl, prc, f1, auc_s, fpr, tpr, loss_clf, _ = evaluate_classifier(model, test_inputs, test_labels, criterion)
            total_loss_clf += loss_clf
            total_acc += acc
            total_rcl += rcl
            total_prc += prc
            total_f1 += f1
            total_auc += auc_s
            total_fpr0 += fpr[0]
            total_tpr0 += tpr[0]
            total_fpr1 += fpr[1]
            total_tpr1 += tpr[1]

    return total_loss_clf/total, total_acc/total, total_rcl/total, total_prc/total, total_f1/total, total_auc/total, total_fpr0, total_tpr0, total_fpr1, total_tpr1

Hyperparameters:

MAX_SEQ_LEN = 20
BATCH_SIZE = 64
LSTM_HIDDEN_DIM = 128
CLF_LR = 0.0001 #5e-2
EMBEDDING_DIM = 100

Here is a screenshot of some epochs of the training phase:

[training log screenshot: the loss stays essentially unchanged across epochs]

And this is a screenshot of the model's outputs in the evaluation phase. As you can see, the model predicts the same number in every batch and for every sample:

[evaluation output screenshot: the same probability, ~0.4950, for every test sample]

The dataset is sampled from HSPAM: 1000 spam tweets and 1000 non-spam tweets, so it is balanced.

There can be multiple issues with your code:

  • First, make sure there aren't any NaNs in your data or loss (see the sketch below for a quick check)
  • Try playing with the activation functions of your network
  • Try a different optimizer and loss function
  • An epoch count of 50 is too high; try 7
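
For the first point, here is a minimal sketch of such a NaN check (the helper name is mine, not from the original code); you could call it inside clf_training right after computing the loss:

import torch

def assert_no_nans(inputs, output, loss):
    # Hypothetical helper: raises as soon as a batch, prediction, or loss turns into NaN.
    if torch.isnan(inputs.float()).any():
        raise ValueError("NaN in inputs")
    if torch.isnan(output).any():
        raise ValueError("NaN in model output")
    if torch.isnan(loss).any():
        raise ValueError("NaN in loss")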

OK, I skimmed through your code, and I think(!) I spotted the issues.

The main problem is that you create the nn.LSTM with batch_first=False (the default). Because of this, you transpose the input in the line

output = model(test_inputs.long().t()) 

You shouldn’t do that, as it messes up the semantics of the embedding step. So change it to

output = model(test_inputs.long()) 

Of course, now your dimensions are off, so you need to transpose after the embedding step. So change

x = self.embedding(x)

to

x = self.embedding(x).transpose(0, 1)
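
To convince yourself of the shape flow, here is a small standalone sketch (the toy dimensions are my own, chosen to match your hyperparameters):

import torch
import torch.nn as nn

batch_size, seq_len, vocab, embed, hidden = 4, 20, 1000, 100, 128
emb = nn.Embedding(vocab, embed)
lstm = nn.LSTM(embed, hidden, num_layers=2, bidirectional=True)  # batch_first=False

tokens = torch.randint(0, vocab, (batch_size, seq_len))  # (batch, seq), as the DataLoader yields it
x = emb(tokens).transpose(0, 1)                          # -> (seq, batch, embed)
out, (h_n, c_n) = lstm(x)
print(out.shape)  # torch.Size([20, 4, 256]) = (seq_len, batch, 2*hidden)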

With this change, you should already see some progress. The last thing is that you handle the output of the LSTM a bit sloppily. The shape of x after

x, self.hidden = self.lstm(x, self.hidden)

is (seq_len, batch_size, num_directions*hidden_dim). The problem is that x[-1] gives you, for each sample, the last time step of the forward pass (which is what you want) but only the first time step of the backward pass (which is not useful to you).
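
You can check this numerically. Continuing the toy sketch from above: only the forward half of x[-1] is a final state, while the backward direction's final state sits at time step 0:

# h_n has shape (num_layers*num_directions, batch, hidden); split the axes and take the top layer.
final = h_n.view(2, 2, batch_size, hidden)[-1]        # (num_directions, batch, hidden)
print(torch.allclose(out[-1, :, :hidden], final[0]))  # True:  forward final state
print(torch.allclose(out[-1, :, hidden:], final[1]))  # False: backward pass has only seen one token here
print(torch.allclose(out[0, :, hidden:], final[1]))   # True:  backward final state is at t=0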

I would recommend not using x but self.hidden. In this case you need to handle the number of LSTM layers, though, as the shape of self.hidden[0] is (num_layers*num_directions, batch_size, hidden_dim). You can try the following:

...
x, self.hidden = self.lstm(x, self.hidden)
# h_n: (num_layers*num_dir, batch, hidden) -> (num_layers, num_dir, batch, hidden), take the top layer
final_state = self.hidden[0].view(self.num_layers, 2, self.batch_size, self.hidden_size)[-1]
h_fwd, h_bwd = final_state[0], final_state[1]
final_state = torch.cat((h_fwd, h_bwd), 1)  # concatenate the final states of both directions
x = self.hidden2out(final_state)
...
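
As a side note, because h_n stacks the layers in order with the forward direction before the backward one, the same two top-layer states can be picked out without the view (my shorthand, equivalent to the snippet above):

# self.hidden[0][-2] is the top layer's forward final state, self.hidden[0][-1] the backward one.
final_state = torch.cat((self.hidden[0][-2], self.hidden[0][-1]), dim=1)
x = self.hidden2out(final_state)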

For the complete code, you can check out my implementation here.