Tips for debugging test-set precision/recall that won't improve

I have a PyTorch MLP model whose test-set precision/recall does not improve no matter what I try. A few high-level details about the dataset and my current observations:

  • The class labeled 1 (an issue) is rare: around 4% of all samples
  • I need roughly 15-20% recall at 98% precision for this use case
  • Precision/recall (PR) on the train set starts to hit the requirement as the number of epochs and model complexity increase
  • But test-set precision is stuck around 35% no matter what I try (dropout layers, asymmetric weights with a higher penalty for errors on class 1, a larger test split, etc.)
  • On the test set, beyond a point the precision even starts to drop a bit (concrete numbers below)
  • The data contains multiple rows from the same log, so to avoid leakage I split train/test such that the two sets never share a value of the "log_id" column
  • All features are numeric. There are about 15 features and 1M samples

Questions:

  • Am I applying the asymmetric weights (to deal with the unbalanced dataset) correctly? This is my code (see also the pos_weight sketch after this list):
loss_function = nn.BCELoss(weight=assym_wt * targets + 1)
loss_function = loss_function.to(device)
loss = loss_function(outputs, targets)
  • Is there anything else I can try beyond dropout and asymmetric weights? (I plan to try resampling next; see the sampler sketch after the training loop.)
  • Do you see anything in my PyTorch code that could be causing this?
  • Am I applying dropout correctly in my MLP class below? (A small smoke test follows the class.)
  • The code is also not using the GPU, even on machines that have one. Any tips on how to debug this would be appreciated (see the device-check sketch after the dataset class).
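
For reference, one alternative I am considering for the class weighting (not what my code above does) is to drop the final sigmoid and use BCEWithLogitsLoss with pos_weight, which up-weights only the positive class. A minimal sketch, assuming a logits-producing model and a ~4% positive rate (so neg/pos is roughly 24):

import torch
import torch.nn as nn

# pos_weight multiplies the loss term for targets == 1; num_neg / num_pos
# (~24 here) is a common starting point for a 4% minority class.
loss_function = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([24.0]))

logits = torch.randn(8, 1)                     # raw model outputs, no sigmoid
targets = torch.randint(0, 2, (8, 1)).float()
loss = loss_function(logits, targets)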

PyTorch MLP class

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, act_fn=nn.ReLU(), use_dropout=False, drop_rate=0.25):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.layers = nn.Sequential()
        # First hidden layer (optionally preceded by dropout on the raw inputs)
        if use_dropout:
            self.layers.append(nn.Dropout(p=drop_rate))
        self.layers.append(nn.Linear(self.input_size, self.hidden_size[0]))
        self.layers.append(act_fn)
        # Remaining hidden layers, each optionally preceded by dropout
        for i in range(1, len(hidden_size)):
            if use_dropout:
                self.layers.append(nn.Dropout(p=drop_rate))
            self.layers.append(nn.Linear(self.hidden_size[i - 1], self.hidden_size[i]))
            self.layers.append(act_fn)
        # Output head: single unit squashed to a probability
        if use_dropout:
            self.layers.append(nn.Dropout(p=drop_rate))
        self.layers.append(nn.Linear(self.hidden_size[-1], 1))
        self.layers.append(nn.Sigmoid())

    def forward(self, x):
        return self.layers(x)
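
A quick smoke test I use to check the dropout wiring (hypothetical sizes): dropout should only be active in train() mode, so repeated eval() passes on the same input should be identical.

import torch

mlp = MLP(input_size=15, hidden_size=[100, 100], use_dropout=True)
x = torch.randn(4, 15)

mlp.train()
print(torch.allclose(mlp(x), mlp(x)))   # usually False: dropout masks differ
mlp.eval()
print(torch.allclose(mlp(x), mlp(x)))   # True: dropout disabled in eval mode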

Dataset loading class

import numpy as np
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, x, y, use_gpu=False):
        x = x.astype(np.float32)
        self.x_train = torch.from_numpy(x)
        self.y_train = torch.from_numpy(y.values)
        if use_gpu:
            device = torch.device("cuda")
            self.x_train.to(device)
            self.y_train.to(device)

    def __len__(self):
        return len(self.y_train)

    def __getitem__(self, idx):
        return self.x_train[idx], self.y_train[idx]
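
For the GPU question, the sketch below is how I plan to check where things actually live (it uses the clf and trainloader names from the training loop further down; note that Tensor.to(device) returns a new tensor rather than moving the original in place):

import torch

print(torch.cuda.is_available())        # is a CUDA device visible at all?
print(next(clf.parameters()).device)    # where the model weights live
inputs, _ = next(iter(trainloader))
print(inputs.device)                    # where a batch starts out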

Train-Test set creation

from sklearn.model_selection import GroupShuffleSplit
from sklearn.preprocessing import StandardScaler

# Only look at samples where the sum across key columns is > 1000
df_input = df_input[df_input[key_cols].sum(axis=1) > 1000]

# Split on log_id so the test set has no rows from any log used in the train set
splitter = GroupShuffleSplit(test_size=0.20, n_splits=2, random_state=7)
split = splitter.split(df_input, groups=df_input['log_id'])
train_inds, test_inds = next(split)
train = df_input.iloc[train_inds]
test = df_input.iloc[test_inds]

# Remove non-feature columns to create the train set
# Labels are 1 if issue% > 2, else 0
y_training = train['issue_pcnt'].apply(lambda issue_pcnt: 1 if issue_pcnt > 2 else 0)
y_testing = test['issue_pcnt'].apply(lambda issue_pcnt: 1 if issue_pcnt > 2 else 0)
feature_cols = train.columns[~train.columns.isin(['log_id', 'ds', 'issue_pcnt'])]
X_training = train[feature_cols]
X_testing = test[feature_cols]

# Normalize feature columns and apply the same normalization coefficients to the test data
scaler = StandardScaler()
X_training = scaler.fit_transform(X_training)
X_testing = scaler.transform(X_testing)
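
As a cheap guard against leakage, I also sanity-check that the split really is disjoint on log_id (same variable names as above):

# GroupShuffleSplit should guarantee this, but an explicit assert catches wiring mistakes
assert set(train['log_id']).isdisjoint(set(test['log_id']))
print(y_training.mean(), y_testing.mean())   # positive rate in each split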

Training Loop

import random
import numpy as np
import torch
import torch.nn as nn

random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
torch.backends.cudnn.deterministic = True

clf = MLP(len(X_training[0]), hidden_size=[100, 100, 100, 100, 100])
# Move the model to GPU if available
use_gpu = torch.cuda.is_available()
device = torch.device('cuda' if use_gpu else 'cpu')

assym_wt = 0  # change to e.g. 20 to put more weight on errors for class 1
# Define the optimizer (the loss function is built per batch in the loop below)
optimizer = torch.optim.Adam(clf.parameters(), lr=8e-4)
clf = clf.to(device)

# Build the dataset and loader once; there is no need to rebuild them every epoch
dataset = MyDataset(X_training, y_training, use_gpu)
kwargs = {'num_workers': 1, 'pin_memory': True} if use_gpu else {}
trainloader = torch.utils.data.DataLoader(dataset, batch_size=10000, shuffle=True, **kwargs)

# Run the training loop
for epoch in range(150):
    # Reset the running loss for this epoch
    current_loss = 0.0

    # Iterate over the DataLoader for training data
    clf.train()  # set to train mode
    for i, data in enumerate(trainloader):
        # Get inputs
        inputs, targets = data
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Perform forward pass
        outputs = clf(inputs)

        # Compute loss (targets reshaped to match the (batch, 1) output)
        targets = targets.float().unsqueeze(1)
        if assym_wt > 0:
            # Per-element weights: errors on class-1 targets get weight assym_wt + 1
            loss_function = nn.BCELoss(weight=assym_wt * targets + 1)
        else:
            loss_function = nn.BCELoss()
        loss = loss_function(outputs, targets)

        # Perform backward pass
        loss.backward()

        # Perform optimization
        optimizer.step()

        # Print running statistics every 20 mini-batches
        current_loss += loss.item()
        if i % 20 == 19:
            print("Loss after mini-batch %5d: %.3f" % (i + 1, current_loss / 20))
            current_loss = 0.0
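
For the resampling idea from the questions above, this is the sketch I plan to try (same MyDataset / y_training / dataset names as above): oversample the minority class with WeightedRandomSampler instead of, or alongside, loss weighting.

import numpy as np
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# Weight each sample inversely to its class frequency so both classes
# appear roughly equally often in each epoch
labels = y_training.values
class_counts = np.bincount(labels)            # [num_zeros, num_ones]
sample_weights = 1.0 / class_counts[labels]

sampler = WeightedRandomSampler(
    weights=torch.as_tensor(sample_weights, dtype=torch.double),
    num_samples=len(labels),
    replacement=True,
)
# Note: sampler and shuffle=True are mutually exclusive in DataLoader
trainloader = DataLoader(dataset, batch_size=10000, sampler=sampler)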

Model Evaluation

import time
from sklearn.metrics import precision_score, recall_score

start_time = time.time()
clf.eval()  # set to eval mode (disables dropout)
with torch.no_grad():
    test_preds = clf(torch.from_numpy(X_testing.astype(np.float32)).to(device)).cpu()
    train_preds = clf(torch.from_numpy(X_training.astype(np.float32)).to(device)).cpu()

print(f"Metric calc time is {time.time() - start_time}")

for (targets, preds, title) in [(y_training, train_preds, 'Training Data'), (y_testing, test_preds, 'Testing Data')]:
    print(f'\n***{title} Performance***')
    for thresh in np.arange(0.7, 0.95, 0.05):
        print(f"Thresh {thresh}: Precision: {precision_score(targets, preds > thresh)}, Recall {recall_score(targets, preds > thresh)}")

Sample output

***Training Data Performance***
Thresh 0.7: Precision: 0.9200912569914631, Recall 0.24288850465292483
Thresh 0.75: Precision: 0.9360716591349257, Recall 0.22535698327278378
Thresh 0.8: Precision: 0.9513260865672176, Recall 0.2064111281642803
Thresh 0.8500000000000001: Precision: 0.9651179183533724, Recall 0.18619858955180385
Thresh 0.9000000000000001: Precision: 0.9776447105788423, Recall 0.16176636294756475

***Testing Data Performance***
Thresh 0.7: Precision: 0.363302002451982, Recall 0.07313742266684217
Thresh 0.75: Precision: 0.3626991565135895, Recall 0.06367645123074898
Thresh 0.8: Precision: 0.3569584856958486, Recall 0.05461037251546663
Thresh 0.8500000000000001: Precision: 0.34896443923407583, Recall 0.044079899960510725
Thresh 0.9000000000000001: Precision: 0.3463246176615688, Recall 0.03465183625115177