Model performance degrades after inference on the first batch

Hello,
I fine-tuned a model (RoBERTa + CRF) on a sequence-labelling task and evaluated it. Performance on the evaluation set was good (0.92 or above for accuracy, recall, and F1).
However, when I use the model for inference, it performs very well on the first batch, but then performance degrades very badly on the remaining batches.

I have run inference on the same 4 batches multiple times, and it always gives near-perfect results for the first batch but meaningless results for the rest. I also shuffled the sentences, and the model still made correct predictions only for whichever batch came first.
For example, even if it labelled sentences 1-8 correctly on the first run, it performs very poorly on those same sentences in a second run where they are not in the first batch.

I tested the following solutions:
1- Checked that the model was in eval() mode
2- Reset the hidden_states after every layer
3- Cleared the GPU cache after every batch
4- Set a constant seed for all the layers (this fixed the problem of getting different results for the same sentence in different batches, but nothing else)
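For point 1, a quick way to confirm that every submodule (not only the top-level one) is in evaluation mode, and that dropout then acts as the identity so repeated passes give identical outputs without any seed, might look like the sketch below. The small `nn.Sequential` model is a stand-in for RobertaCRF, and `assert_eval_mode` is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def assert_eval_mode(model: nn.Module) -> None:
    # After model.eval(), every submodule should report training == False
    still_training = [name for name, m in model.named_modules() if m.training]
    assert not still_training, f"modules still in training mode: {still_training}"


# Stand-in model with dropout (p=0.1, as in RobertaCRF)
model = nn.Sequential(nn.Linear(16, 16), nn.Dropout(0.1), nn.Linear(16, 2))
model.eval()
assert_eval_mode(model)

x = torch.randn(4, 16)
with torch.no_grad():
    first = model(x)
    second = model(x)

# Dropout is disabled in eval mode, so both passes match exactly
print(torch.equal(first, second))  # True
```

If a check like this passes for the real model, eval-mode randomness is ruled out and the cause has to be elsewhere.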

The code looks like this:

```python
import torch
from tqdm import tqdm


def annotate_dataset(tokenizer, classifier, dataloader, max_length, device, threshold=None):
    # Transfer the model to the target device
    classifier.to(device)
    # Switch to evaluation mode (disables dropout)
    classifier.eval()

    predictions = []
    annotated_results = []
    with torch.no_grad():
        processed_batches = 0
        for batch in tqdm(dataloader):
            # Strip trailing whitespace/newlines from each raw sentence
            batch = [line.rstrip() for line in batch]
            # Tokenize, padding each batch to its own longest sentence
            tokenized_input = tokenizer.batch_encode_plus(batch,
                                                          add_special_tokens=True,
                                                          truncation=True,
                                                          padding=True,
                                                          max_length=max_length,
                                                          return_offsets_mapping=True
                                                          )

            input_ids = torch.tensor(tokenized_input.input_ids).to(device)
            attention_mask = torch.tensor(tokenized_input.attention_mask).to(device)

            output = classifier(input_ids, attention_mask, threshold=threshold)

            predictions.extend(output)

            # format_result is my own helper (not shown)
            results = format_result(batch,
                                    tokenized_input.input_ids,
                                    tokenized_input.offset_mapping,
                                    predictions
                                    )

            annotated_results.extend(results)

            processed_batches += 1
            if processed_batches % 50 == 0:
                # Every 50 batches, write results to file (omitted here)
                pass
```

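```python
import torch.nn.functional as F


# Forward pass of RobertaCRF (a method of the class shown further down)
def forward(self, input_ids, attention_mask, threshold=None):
    outputs = self.roberta(input_ids, attention_mask=attention_mask)
    sequence_output = self.dropout(outputs[0])     # [batch, seq_len, hidden]
    emissions = self.hidden2tag(sequence_output)   # [batch, seq_len, num_labels]

    # Viterbi decoding: one list of tag indices per sequence
    predictions = self.crf.decode(emissions, mask=attention_mask.byte())
    probabilities = F.softmax(emissions, dim=-1)

    # Extract the confidence for the predicted label of each token
    confident_predictions = []
    # for token in predictions:
    #     if confidence > threshold:
    #         set_label

    return confident_predictions
```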
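The commented-out thresholding step could be filled in roughly as sketched below. This is an assumption about the intended behaviour: a token whose softmax confidence for its CRF-decoded tag does not exceed `threshold` is reset to a hypothetical `fallback_label` (0 here). A softmax over the emissions ignores the CRF transition scores, so it is only a per-token proxy, not the CRF's own marginal probability. The decoded tags are written by hand (one list per sentence, trimmed to its unpadded length) so the example runs without the CRF package.

```python
import torch
import torch.nn.functional as F


def apply_confidence_threshold(emissions, decoded, threshold=None, fallback_label=0):
    """emissions: [batch, seq_len, num_labels]; decoded: list of tag-index lists,
    one per sequence, already trimmed to its unpadded length."""
    if threshold is None:
        return decoded  # no thresholding requested
    probabilities = F.softmax(emissions, dim=-1)
    confident_predictions = []
    for i, tags in enumerate(decoded):
        tag_tensor = torch.tensor(tags, dtype=torch.long, device=emissions.device)
        # Probability assigned to the decoded tag at each unpadded position
        confidence = probabilities[i, : len(tags)].gather(-1, tag_tensor.unsqueeze(-1)).squeeze(-1)
        confident_predictions.append(
            [tag if conf > threshold else fallback_label
             for tag, conf in zip(tags, confidence.tolist())]
        )
    return confident_predictions


emissions = torch.tensor([[[2.0, 0.0], [0.0, 0.1], [0.0, 0.0]]])  # 1 sentence, 3 positions
decoded = [[0, 1]]  # as if the last position were padding
print(apply_confidence_threshold(emissions, decoded, threshold=0.6))  # [[0, 0]]
print(apply_confidence_threshold(emissions, decoded, threshold=0.5))  # [[0, 1]]
```

The second token's decoded tag has a softmax confidence of about 0.525, so it survives a threshold of 0.5 but is replaced at 0.6.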

Does someone have an idea why this is happening?
Thank you.

Do you see the same effect for random inputs? If so, could you post the model definition as well as the input shapes to this model?
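One way to run such a check is to feed random token ids and compare each sentence's prediction when it is processed alone with its prediction inside a padded batch; in eval mode the two should agree. The sketch below uses a toy embedding + linear model (`ToyTagger`, hypothetical) as a stand-in for RobertaCRF, padding id 0, and argmax in place of CRF decoding.

```python
import torch
import torch.nn as nn


class ToyTagger(nn.Module):
    # Hypothetical stand-in: embedding -> dropout -> linear, argmax decoding
    def __init__(self, vocab_size=100, hidden=32, num_labels=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.dropout = nn.Dropout(0.1)
        self.hidden2tag = nn.Linear(hidden, num_labels)

    def forward(self, input_ids, attention_mask):
        emissions = self.hidden2tag(self.dropout(self.embed(input_ids)))
        lengths = attention_mask.sum(dim=1).tolist()
        # Like a masked CRF decode: one tag list per sequence, trimmed to its length
        return [emissions[i, :n].argmax(-1).tolist() for i, n in enumerate(lengths)]


def pad(seqs, pad_id=0):
    max_len = max(len(s) for s in seqs)
    ids = torch.tensor([s + [pad_id] * (max_len - len(s)) for s in seqs])
    mask = torch.tensor([[1] * len(s) + [0] * (max_len - len(s)) for s in seqs])
    return ids, mask


def check_batch_independence(model, seqs):
    # Predictions for each sequence should not depend on what it is batched with
    model.eval()
    with torch.no_grad():
        batched = model(*pad(seqs))
        alone = [model(*pad([s]))[0] for s in seqs]
    return batched == alone


torch.manual_seed(0)
random_seqs = [torch.randint(1, 100, (n,)).tolist() for n in (5, 9, 3, 7)]
print(check_batch_independence(ToyTagger(), random_seqs))  # True
```

For a real transformer, tiny floating-point differences between padded and unpadded runs are possible, so an occasional mismatch on a near-tie is not alarming. If the check passes but the saved annotations are still wrong, the problem is more likely downstream, in how predictions are matched back to sentences.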

I don’t quite understand what you mean by “the same effect for random inputs”.

The model definition is as follows:

```python
import torch.nn as nn
from transformers import AutoModel
from torchcrf import CRF  # assumed: the pytorch-crf package


class RobertaCRF(nn.Module):
    def __init__(self, model_name, num_labels):
        super().__init__()
        self.num_labels = num_labels
        self.roberta = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        self.hidden2tag = nn.Linear(self.roberta.config.hidden_size, self.num_labels)
        self.crf = CRF(self.num_labels, batch_first=True)
```

The inputs have different lengths. The hidden-layer output has shape [batch size, sequence length, 1024], and the output is one of two classes.
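The shapes described above can be checked with random tensors in place of the RoBERTa output (hidden size 1024 and two classes as stated in the post; the batch sizes and sequence lengths below are arbitrary). Because `padding=True` pads each batch only to its own longest sentence, the sequence length changes from batch to batch, and the real per-sentence lengths are recovered from the attention mask.

```python
import torch
import torch.nn as nn

hidden_size, num_labels = 1024, 2
hidden2tag = nn.Linear(hidden_size, num_labels)

shapes = []
for batch_size, seq_len in [(8, 12), (8, 31)]:  # arbitrary example batches
    # Random stand-in for outputs[0] of RoBERTa: [batch size, sequence length, 1024]
    sequence_output = torch.randn(batch_size, seq_len, hidden_size)
    emissions = hidden2tag(sequence_output)  # [batch size, sequence length, num_labels]
    shapes.append(tuple(emissions.shape))
print(shapes)  # [(8, 12, 2), (8, 31, 2)]

# Per-sentence lengths come from the attention mask (1 = real token, 0 = padding)
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
lengths = attention_mask.sum(dim=1)
print(lengths.tolist())  # [3, 2]
```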

Hello,

I finally figured it out. I had a mistake in the labels / tokens alignment. Thank you for your help @ptrblck
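The post does not say exactly where the misalignment was. One alignment pitfall that would produce this exact symptom is visible in `annotate_dataset` above: `format_result` receives the accumulated `predictions` list rather than the current batch's `output`. If `format_result` pairs sentences with predictions by position, every later batch gets matched with the first batch's predictions, regardless of shuffling. The sketch below uses hypothetical stand-ins (`fake_model`, `pair_by_position`) to show the difference.

```python
def fake_model(batch):
    # Hypothetical model: "predicts" the length of each sentence
    return [len(sentence) for sentence in batch]


def pair_by_position(batch, batch_predictions):
    # Hypothetical stand-in for format_result: i-th sentence gets the i-th prediction
    return list(zip(batch, batch_predictions))


def annotate(batches, pass_accumulated):
    predictions, annotated = [], []
    for batch in batches:
        output = fake_model(batch)
        predictions.extend(output)
        # Buggy: the accumulated list always starts with batch 1's predictions
        to_format = predictions if pass_accumulated else output
        annotated.extend(pair_by_position(batch, to_format))
    return annotated


batches = [["a", "bb"], ["ccc", "dddd"]]
print(annotate(batches, pass_accumulated=True))   # [('a', 1), ('bb', 2), ('ccc', 1), ('dddd', 2)]
print(annotate(batches, pass_accumulated=False))  # [('a', 1), ('bb', 2), ('ccc', 3), ('dddd', 4)]
```

Passing only the current batch's predictions to the formatting step keeps each sentence aligned with its own labels.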