PyTorch converting LSTM code to accept batches gives no prediction power

I have an issue where if I feed data row by row, then my binary-classification LSTM model gets an AUC of 0.9+ within a few epochs. When I changed to code so that it accepts batches, the AUC gets stuck at 0.5 despite the loss decreasing.

I created a toy example of the LSTM model to try to test out this issue. I am suspecting that my model architecture is passing in the wrong information, because at some point it just predicts everything as “positive”, but I don’t know where. Code is below:

Import and helper functions:

import pandas as pd
import numpy as np
from sklearn.metrics import roc_curve, auc, roc_auc_score

import multiprocessing

import pickle
import ast 
import s3fs
import json

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils.rnn as rnn_utils

import random
from random import randint
import os
import math
import time

# returns the softmax output into a readable prediction
def get_max_prob_result(input, ix_to_word):
    return ix_to_word[get_index_of_max(input)] 

# calculates the roc_auc score for the predictions
def get_auroc(truth, pred):
    assert len(truth) == len(pred)
    auc_score = roc_auc_score(np.array(truth),np.array(pred))
    return auc_score
    
def reorder_list(list, new_index_list):
    new_list = []
    for index in new_index_list:
        new_list.append(list[index])
    return new_list

def grab_batch(batch_size):
    seq=[]
    freq=[]
    target=[]
    time = []
    for k in range(batch_size):
        tseq, tfreq, ttime, ttarget = generate_patient()
        seq.append(tseq)
        freq.append(tfreq)
        time.append(ttime)
        target.append(ttarget)
    
    return seq, freq, time, target

Function to randomly generate data with structure that matches my real use case (note that I created a rule to define when target is positive that the model will learn):

events_to_ix = {'<PAD>':0,'non':1,'othernon':2,'neutral':3,'trigger':4}

final_seq = []
final_freq = []
final_time = []
final_target = []

dict_keys = list(events_to_ix.keys())[1:]

def generate_patient():
    num_seq = randint(1,100)
    patient_seq = []
    patient_freq = []
    patient_time = []
    patient_target = 0

    
    final_seq = []
    final_freq = []
    final_time = []
    
    for i in range(num_seq):
        step_seq = []
        step_freq = []
        step_time = []
        seq_length = randint(1,10)
        for k in range(seq_length):
            event = random.choice(dict_keys)
            if events_to_ix[event] in step_seq:
                continue
            step_seq.append(events_to_ix[event])
            step_freq.append(randint(1, 17))
            step_time.append(randint(0,(5+seq_length-k)))

        patient_seq.append(step_seq)
        patient_freq.append(step_freq)
        patient_time.append(step_time)

    for index, item in enumerate(patient_seq[-1]): 
        if item == 4 and patient_freq[-1][index] > 15 and patient_time[-1][index] < 3:
            patient_target = 1
            break

    # Loop through each concet in each timestep
    for step_idx, step in enumerate(patient_seq):
        concepts = [0]*len(events_to_ix)
        frequencies = [0]*len(events_to_ix)
        times = [0]*len(events_to_ix)
        for event_idx, event in enumerate(step):

            # Convert the textual concepts into their index representation
            concepts[event] = event

            # Append the frequencies and time encodings to their appropriate position in the list of 0's
            frequencies[event] = patient_freq[step_idx][event_idx]
            times[event] = patient_time[step_idx][event_idx]
            
        final_seq.append(concepts)
        final_freq.append(frequencies)
        final_time.append(times)
        #final_static.append(patient_static)
        
    final_seq = torch.LongTensor(final_seq)
    final_freq = torch.FloatTensor(final_freq).view(-1,len(events_to_ix),1)
    final_time = torch.FloatTensor(final_time).view(-1,len(events_to_ix),1)
    #final_static = torch.FloatTensor(final_static)
    
    return final_seq, final_freq, final_time, patient_target

The LSTM model:

# Class containing the LSTM model initialization and feed-forward logic
class LSTMClassifier(nn.Module):
    # LSTM initialization
    def __init__(self, embedding_dim, hidden_dim, vocab_size, label_size):
        super(LSTMClassifier, self).__init__()

        # Setting the hidden layer dimension of the LSTM
        self.hidden_dim = hidden_dim
        # Initializing the embedding layer
        self.embeddings = nn.Embedding(vocab_size, embedding_dim-2)
        # Initializing the LSTM layer with one hidden layer 
        self.lstm = nn.LSTM(((embedding_dim*vocab_size)), hidden_dim, num_layers = 1, batch_first=False)
        # Initializing linear linear that takes the hidden layer output
        self.hidden2label = nn.Linear(hidden_dim, label_size)


    # Defining the hidden state of the LSTM
    def init_hidden(self,batch_size):
        # the first is the hidden h
        # the second is the cell  c
        return [autograd.Variable(torch.zeros(1,batch_size, self.hidden_dim).cuda()),
                autograd.Variable(torch.zeros(1,batch_size, self.hidden_dim).cuda())]

    # Defining the feed forward logic of the LSTM. It contains:
    # 1. The embedding layer
    # 2. The LSTM layer with one hidden layer
    # 3. The softmax layer
    def forward(self, seq, freq, time_data, seq_lengths):

        # Grab the mini-batch length and max sequence length (pre-ordered)
        # (need to do this in the forward logic because of data parallelism and how the GPU's will split up the batch)
        sequence_length = seq.size()[1]
        batch_length = seq.size()[0]
        
        # reset the LSTM hidden state. 
        # Must be done before you run a new batch. Otherwise the LSTM will treat a new batch as a continuation of a sequence
        self.hidden = self.init_hidden(batch_length)
        
        # This is the pass to the embedding layer. 
        # The sequence is of dimension N and the output is N x Demb
        embeds = self.embeddings(seq)

        # Concatenate the embedding output with the time and frequency vectors
        embeds = torch.cat((embeds,freq), dim=3)
        embeds = torch.cat((embeds,time_data), dim=3)

        # Because the LSTM excepts a dimension of (sequence length, batch size, feature size), and we have (batch size, seq length, feature size),
        # we need to switch the first and second dimension so that we get the correct input format
        embeds = torch.transpose(embeds, 0, 1)

        # Flatten the embedding dimension so that the input to the LSTM remains 3D rather than 4D
        x = embeds.view(sequence_length, batch_length, -1) 

        # pack the padded sequence so that paddings are ignored
        packed_x = torch.nn.utils.rnn.pack_padded_sequence(x, seq_lengths, batch_first=False)

        # Feed to the LSTM layer
        lstm_out, self.hidden = self.lstm(packed_x, self.hidden)

        # Feed the last layer of the LSTM into the linear layer
        y = self.hidden2label(self.hidden[0][-1])
        
        # Produce the softmax probabilities
        log_probs = F.log_softmax(y)
        
        return log_probs

Function to run a single epoch:

def train_epoch(model, loss_function, optimizer,batch_size, i):
    # Set model to training mode and initialize variables
    model.train()
    avg_loss = 0.0
    count = -1
    truth_res = []
    pred_res = []

    # Group the dataframe into dataframe chunks of length batch size and loop through each batch
    for j in range(1000):
        count += 1
        
        seq, freq, time_data, target = grab_batch(batch_size)
        
        # Sort the batches by descending size
        final_seq_ordered = sorted(enumerate(seq), key=lambda x: len(x[-1]), reverse=True) 
    
        # Grab the original indices from final_seq and see how they are now ordered
        final_seq_indices = [item[0] for item in final_seq_ordered]
        
        # Grab the actual values from the tupled master list of concepts
        seq = [item[1] for item in final_seq_ordered]
        freq = reorder_list(freq, final_seq_indices)
        time_data = reorder_list(time_data, final_seq_indices)
        target = reorder_list(target, final_seq_indices)
        
        # Grab the list of lengths of sequences, for the purpose of packing the padded sequenes
        seq_lengths = torch.LongTensor(list(map(len, seq)))
        
        # Grab the targets into a list and append it into the truth_res list in order to measure AUC performance
        truth_res.extend(target)
        
        # Pad the sequences
        seq = rnn_utils.pad_sequence(seq, batch_first = True)
        freq = rnn_utils.pad_sequence(freq, batch_first = True)
        time_data = rnn_utils.pad_sequence(time_data, batch_first = True)

        # Put the padded sequences into Variable and Cuda cores
        seq = autograd.Variable(seq.cuda())
        freq = autograd.Variable(freq.cuda())
        time_data = autograd.Variable(time_data.cuda())
        target = autograd.Variable(torch.LongTensor(target).cuda())
        
        # Feed the tensor Variables into the model
        pred = model(seq,freq,time_data,seq_lengths)
        # Append the predictions into a list for future AUC evaluation
        pred_label = pred.detach().max(1)[1].cpu().numpy()
        pred_res.extend(pred_label)

        # Reset the model gradient
        model.zero_grad()
        # Compute the loss
        loss = loss_function(pred, target)
        # Backpropagate
        loss.backward()
        # Update weights
        optimizer.step()
        
        # Computes the average loss
        avg_loss += loss.detach().item()
        
    # Computes the AUC score
    auc_score = get_auroc(truth_res, pred_res)     
    avg_loss /= (1000/batch_size)
    print('epoch: %d done! \n train avg_loss:%g , auc:%g' % (i, avg_loss, auc_score))

Main training loop:

#############################
### Set hyper parameters ###
############################
EMBEDDING_DIM = 32
HIDDEN_DIM = 50
EPOCH = 10
BATCH_SIZE = 16
best_val_auc = 0.5

model = LSTMClassifier(embedding_dim=EMBEDDING_DIM, hidden_dim=HIDDEN_DIM, vocab_size=(len(events_to_ix)), label_size=2)
model = torch.nn.DataParallel(model.cuda())

weights = [(26/1000), 1]
class_weights = torch.FloatTensor(weights).cuda()
loss_function = nn.NLLLoss(weight=class_weights,reduction="sum",ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=0.001)

no_up = 0

#####################################################
### Set loop to determine number of EPOCHs to run ###
#####################################################
for i in range(EPOCH):
    #############################################
    ### Run the training on the training data ###
    #############################################
    print('epoch: %d start!' % i)
    start = time.time()
    
    # Perform the training on the epoch
    train_epoch(model,loss_function, optimizer,BATCH_SIZE, i)
    
    print("1 epoch length of time")
    print(time.time() - start)