Problem with BiLSTM-CRF model for punctuation restoration

Hello everyone,

I changed the code in this tutorial so that it works for punctuation restoration (only periods and commas for now) instead of NER.

The issue is this: after training, I get good results (precision, recall and F1-score are all nearly 1), which should mean that the model is trained well. But when testing, I get 0.00 for all three metrics. I trained the model on datasets of various sizes (from 2,000 to 3,000 lines).

I would be very happy for any advice on what I should change. I also wonder whether I am testing the model the right way and whether I am computing the metrics correctly. Thank you very much in advance.

import numpy as np
import torch
import torch.optim as optim
from sklearn import metrics
# BiLSTM_CRF, prepare_sequence, START_TAG and STOP_TAG are taken from the tutorial code.
# train, tags_train, test and tags_test are whitespace-separated strings of words and tags.

EMBEDDING_DIM = 5
HIDDEN_DIM = 4

training_data = [(train.split(), tags_train.split())]
testing_data = [(test.split(), tags_test.split())]

word_to_ix = {}
for sentence, tags in training_data:
    for word in sentence:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)

tag_to_ix = {"W": 0, "P": 1, "C": 2, START_TAG: 3, STOP_TAG: 4}
model = BiLSTM_CRF(len(word_to_ix), tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM)
optimizer = optim.SGD(model.parameters(), lr=0.02, weight_decay=0.1)

for epoch in range(20):
    for sentence, tags in training_data:
        model.zero_grad()
        sentence_in = prepare_sequence(sentence, word_to_ix)
        targets = torch.tensor([tag_to_ix[t] for t in tags], dtype=torch.long)
        loss = model.neg_log_likelihood(sentence_in, targets)
        loss.backward()
        optimizer.step()

word_to_ix = {}
for sentence, tags in testing_data:
    for word in sentence:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)

for sentence, tags in testing_data:
    test_targets = torch.tensor([tag_to_ix[t] for t in tags], dtype=torch.long)
with torch.no_grad():
    test_precheck_sent = prepare_sequence(testing_data[0][0], word_to_ix)
    var = model(test_precheck_sent)  # returns (Viterbi score, predicted tag indices)

y_true = np.array(test_targets)
y_pred = np.array(var[1])
print(metrics.confusion_matrix(y_true, y_pred), "\n")
print(metrics.classification_report(y_true, y_pred, digits=3))
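
One way to check whether the metrics themselves are computed correctly is to work out per-tag precision, recall and F1 by hand and compare them with the `classification_report` output. The following is a minimal pure-Python sketch; the function name `per_tag_scores` and the toy sequences are illustrative only, and it assumes the gold and predicted tags are aligned lists of tag indices (like `test_targets` and `var[1]` above).

```python
from collections import Counter


def per_tag_scores(y_true, y_pred, tags):
    """Return {tag: (precision, recall, f1)} for two aligned tag sequences."""
    assert len(y_true) == len(y_pred), "gold and predicted sequences must align"
    tp, fp, fn = Counter(), Counter(), Counter()
    for gold, pred in zip(y_true, y_pred):
        if gold == pred:
            tp[gold] += 1  # correct prediction for this tag
        else:
            fp[pred] += 1  # predicted tag was wrong
            fn[gold] += 1  # gold tag was missed
    scores = {}
    for tag in tags:
        p = tp[tag] / (tp[tag] + fp[tag]) if tp[tag] + fp[tag] else 0.0
        r = tp[tag] / (tp[tag] + fn[tag]) if tp[tag] + fn[tag] else 0.0
        f1 = 2 * p * r / (p + r) if p + r else 0.0
        scores[tag] = (p, r, f1)
    return scores


# Hypothetical toy data using the indices from tag_to_ix: "W"=0, "P"=1, "C"=2.
y_true = [0, 0, 2, 0, 1]
y_pred = [0, 0, 2, 0, 0]
for tag, (p, r, f1) in per_tag_scores(y_true, y_pred, tags=[0, 1, 2]).items():
    print(tag, round(p, 3), round(r, 3), round(f1, 3))
```

In this toy example the missed period (gold 1, predicted 0) gives tag 1 a precision and recall of 0.0 and lowers the precision of tag 0 to 0.75.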

PS: I do not have much experience with either Python or PyTorch, which is why my code is quite messy. I know it could be written much more efficiently, but for now I just want the code to work.

You should have just one word_to_ix, built from the training data and shared by both the train and test data. word_to_ix should not be modified after training.
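
The fix can be sketched as follows: build `word_to_ix` once from the training data and reuse it unchanged for the test data. This is a minimal plain-Python sketch, not the tutorial's code; the `<UNK>` token and the helpers `build_vocab` and `to_indices` are hypothetical additions, and `to_indices` stands in for the tutorial's `prepare_sequence` (which returns a tensor and raises `KeyError` on unseen words).

```python
UNK = "<UNK>"  # hypothetical unknown-word token; the tutorial's vocabulary has none


def build_vocab(training_data):
    """Build word_to_ix once, from the training sentences only."""
    word_to_ix = {UNK: 0}
    for sentence, _tags in training_data:
        for word in sentence:
            if word not in word_to_ix:
                word_to_ix[word] = len(word_to_ix)
    return word_to_ix


def to_indices(seq, word_to_ix):
    """Like prepare_sequence, but returns a list and maps unseen words to <UNK>."""
    unk = word_to_ix[UNK]
    return [word_to_ix.get(w, unk) for w in seq]


training_data = [("the cat sat".split(), ["W", "W", "P"])]
testing_data = [("the dog sat".split(), ["W", "W", "P"])]

word_to_ix = build_vocab(training_data)  # built once, never rebuilt for the test set
print(to_indices(testing_data[0][0], word_to_ix))
```

With the shared vocabulary, "the" and "sat" keep the indices 1 and 3 they had during training. Rebuilding `word_to_ix` from the test sentence instead would map "the", "dog" and "sat" to 0, 1 and 2, so the model would look up embeddings learned for different words. If you add an `<UNK>` entry, create the model after building the vocabulary so that `len(word_to_ix)` includes it; its embedding only gets trained if some training words are replaced by `<UNK>`.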

Thank you very much, it helped!

Can you please explain how to use this on a text file in which the punctuation has to be restored?
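
For a plain text file, one approach is to turn each line into a list of words plus one tag per word by stripping the punctuation, train on those pairs, and then re-attach punctuation from the predicted tags. The sketch below assumes a tag scheme the thread does not spell out: "P" for a word followed by a period, "C" for a word followed by a comma and "W" for no punctuation. The helper names and the one-example-per-line split are hypothetical choices.

```python
def text_to_example(text):
    """Split raw text into words and W/P/C tags, removing trailing periods and commas."""
    words, tags = [], []
    for token in text.split():
        word = token.rstrip(".,")
        if not word:
            continue  # token was only punctuation; skipped for simplicity
        if token.endswith("."):
            tags.append("P")
        elif token.endswith(","):
            tags.append("C")
        else:
            tags.append("W")
        words.append(word)
    return words, tags


def restore_punctuation(words, tags):
    """Inverse of text_to_example: re-attach a period or comma according to each tag."""
    marks = {"P": ".", "C": ",", "W": ""}
    return " ".join(w + marks[t] for w, t in zip(words, tags))


def load_file(path):
    """Read a text file and return one (words, tags) example per line with words."""
    with open(path, encoding="utf-8") as f:
        examples = [text_to_example(line) for line in f]
    return [ex for ex in examples if ex[0]]


words, tags = text_to_example("Hello world, this is a test. Bye.")
print(words)
print(tags)
print(restore_punctuation(words, tags))
```

`load_file(path)` returns a list of `(words, tags)` pairs in the same format as `training_data` in the question. At prediction time, the tag indices would come from `model(prepare_sequence(words, word_to_ix))[1]` (using the training `word_to_ix`), mapped back to letters with `ix_to_tag = {i: t for t, i in tag_to_ix.items()}` and then passed to `restore_punctuation`.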