Need Help with Improving Precision in Discourse Boundary Detection Model

Hello Pytorch Community,

I am currently working on a discourse segmentation project and have hit a roadblock. I’ve developed a model to classify tokens as either being a discourse boundary or not. The model is built using PyTorch Lightning and BERT.

Issue: The main issue I’m facing is that the model consistently predicts 1s (boundaries), resulting in extremely low precision (<0.001) and the recall is 1.0. This indicates a high number of false positives.

Model Details: The model architecture is as follows:

  • BERT for token embeddings.
  • A dropout layer (0.5).
  • A linear layer for classification.

I am using the F.logsigmoid function for calculating log probabilities and F.nll_loss for loss calculation. Additionally, I’ve added L2 regularization.

class BoundaryClassification(pl.LightningModule):
    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.5)
        self.segmentation_head = nn.Linear(self.bert.config.hidden_size, 1)

        
    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.segmentation_head(sequence_output)

        return logits

    def process_outputs(self, outputs):
        predictions = torch.argmax(outputs, dim=-1)
        y_pred = predictions.flatten().tolist()
   
        return y_pred

    def training_step(self, batch):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['edus']
        logits = self(input_ids=input_ids, attention_mask=attention_mask)

        log_probs = F.logsigmoid(logits)

        loss = F.nll_loss(log_probs.view(-1), labels.view(-1), weight=self.pos_weight)

        lambda_reg = 1e-4
        l2_reg = sum(p.pow(2.0).sum() for p in self.parameters())
        l2_reg = lambda_reg * l2_reg / 2


        probs = torch.sigmoid(logits).squeeze(2)

        y_pred = (probs > 0.5).long().flatten().tolist()
        y_true = labels.cpu().flatten().detach().numpy().tolist()
        precision, recall = calculate_precision_recall(y_true, y_pred)
        pk = calculate_pk_score(y_true, y_pred)
        wd = calculate_windowdiff(y_true, y_pred)

        batch_size = len(labels)
        self.log("train_precision", precision, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log('train_recall', recall, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log('train_pk_score', pk, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log('train_windowdiff', wd, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)

        return loss

Specific Questions:

  1. Could there be an issue with how I’m processing the logits or setting up the loss function?
  2. Are there any specific techniques or adjustments in the model architecture that could help balance precision and recall?
  3. Any insights on how to effectively handle class imbalance in this context?

I appreciate any guidance or suggestions you can offer. Thank you for your time and help!