Hello PyTorch Community,
I am currently working on a discourse segmentation project and have hit a roadblock. I’ve developed a model to classify tokens as either being a discourse boundary or not. The model is built using PyTorch Lightning and BERT.
Issue: The model consistently predicts 1 (boundary), giving a recall of 1.0 but an extremely low precision (<0.001). In other words, it produces a very high number of false positives.
Model Details: The model architecture is as follows:
- BERT for token embeddings.
- A dropout layer (0.5).
- A linear layer for classification.
I am using the F.logsigmoid function for calculating log probabilities and F.nll_loss for loss calculation. Additionally, I’ve added L2 regularization.
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

# calculate_precision_recall, calculate_pk_score and calculate_windowdiff
# are metric helpers defined elsewhere (not shown).


class BoundaryClassification(pl.LightningModule):
    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.5)
        # One logit per token: boundary vs. non-boundary.
        self.segmentation_head = nn.Linear(self.bert.config.hidden_size, 1)
        # self.pos_weight (used in training_step) is set elsewhere (not shown).

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs.last_hidden_state  # (batch, seq_len, hidden)
        sequence_output = self.dropout(sequence_output)
        logits = self.segmentation_head(sequence_output)  # (batch, seq_len, 1)
        return logits

    def process_outputs(self, outputs):
        predictions = torch.argmax(outputs, dim=-1)
        y_pred = predictions.flatten().tolist()
        return y_pred

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['edus']

        logits = self(input_ids=input_ids, attention_mask=attention_mask)
        log_probs = F.logsigmoid(logits)
        loss = F.nll_loss(log_probs.view(-1), labels.view(-1), weight=self.pos_weight)

        # L2 penalty over all parameters (computed here; not added to `loss` in this snippet).
        lambda_reg = 1e-4
        l2_reg = sum(p.pow(2.0).sum() for p in self.parameters())
        l2_reg = lambda_reg * l2_reg / 2

        # Threshold the per-token probabilities at 0.5 to get hard predictions.
        probs = torch.sigmoid(logits).squeeze(2)
        y_pred = (probs > 0.5).long().flatten().tolist()
        y_true = labels.cpu().flatten().detach().numpy().tolist()

        precision, recall = calculate_precision_recall(y_true, y_pred)
        pk = calculate_pk_score(y_true, y_pred)
        wd = calculate_windowdiff(y_true, y_pred)

        batch_size = len(labels)
        self.log("train_precision", precision, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log('train_recall', recall, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log('train_pk_score', pk, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log('train_windowdiff', wd, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
        return loss
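For comparison with the loss above, here is a minimal plain-Python sketch of binary cross-entropy computed directly from one raw logit per token. As far as the PyTorch documentation describes it, this is the quantity `F.binary_cross_entropy_with_logits` computes (including its `pos_weight` argument, with mean reduction). The helper names and the example logits are hypothetical and only illustrate the formula; this is not a verified fix for the issue.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    # log(sigmoid(z)) = min(z, 0) - log(1 + exp(-|z|))
    return min(z, 0.0) - math.log1p(math.exp(-abs(z)))


def bce_with_logits(logits, labels, pos_weight=1.0):
    """Mean binary cross-entropy over tokens, one raw logit per token.

    pos_weight scales the loss of positive (boundary) tokens only.
    """
    total = 0.0
    for z, y in zip(logits, labels):
        # log(1 - sigmoid(z)) is the same as log(sigmoid(-z)).
        total -= pos_weight * y * log_sigmoid(z) + (1 - y) * log_sigmoid(-z)
    return total / len(logits)


# Hypothetical example: four tokens, only the third is a boundary.
labels = [0, 0, 1, 0]
all_positive = [5.0, 5.0, 5.0, 5.0]        # "boundary" predicted everywhere
matches_labels = [-5.0, -5.0, 5.0, -5.0]   # predictions agree with the labels

loss_all_positive = bce_with_logits(all_positive, labels)
loss_matches_labels = bce_with_logits(matches_labels, labels)
print(loss_all_positive, loss_matches_labels)
```

Unlike a loss built only from log(sigmoid(z)), this one includes the log(1 - sigmoid(z)) term for negative tokens, so predicting a boundary everywhere is penalized rather than rewarded.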
Specific Questions:
- Could there be an issue with how I’m processing the logits or setting up the loss function?
- Are there any specific techniques or adjustments in the model architecture that could help balance precision and recall?
- Any insights on how to effectively handle class imbalance in this context?
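
To make the last two questions concrete, here is a minimal plain-Python sketch of two common options: estimating a positive-class weight as the ratio of negative to positive tokens, and choosing the decision threshold on held-out data instead of fixing it at 0.5. All function names, labels, and probabilities below are hypothetical and for illustration only; `precision_recall` is a stand-in, not the `calculate_precision_recall` helper from the code above.

```python
def estimate_pos_weight(labels):
    """Ratio of negative to positive tokens, a common choice for pos_weight."""
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    if n_pos == 0:
        raise ValueError("no positive labels found")
    return n_neg / n_pos


def precision_recall(y_true, y_pred):
    """Precision and recall for the positive (boundary) class."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


def best_threshold(y_true, probs, candidates):
    """Return the first candidate threshold with the highest F1, and that F1."""
    best_t, best_f1 = None, -1.0
    for t in candidates:
        y_pred = [1 if p > t else 0 for p in probs]
        precision, recall = precision_recall(y_true, y_pred)
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        if f1 > best_f1:
            best_t, best_f1 = t, f1
    return best_t, best_f1


# Hypothetical validation labels and predicted boundary probabilities.
y_true = [0, 0, 0, 1, 0, 0, 0, 0, 1, 0]
probs = [0.6, 0.55, 0.7, 0.9, 0.52, 0.58, 0.61, 0.65, 0.85, 0.57]

pos_weight = estimate_pos_weight(y_true)
threshold, f1 = best_threshold(y_true, probs, [0.5, 0.6, 0.7, 0.8])
print(pos_weight, threshold, f1)
```

In this toy data every probability is above 0.5, so the default threshold labels every token a boundary, while a higher threshold separates the two classes. Whether either option helps on real data would need to be checked on a validation set.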
I appreciate any guidance or suggestions you can offer. Thank you for your time and help!