F1 score suddenly drops to zero while training a BERT + LSTM NER model

torch 1.1.0

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertPreTrainedModel, BertModel

NUM_LAYER = 3

class Bert_LSTM(BertPreTrainedModel):
    """
        Defining the network here. We inherit from BertPreTrainedModel and
        add a bidirectional LSTM plus a linear classifier as the output layer.
    """
    def __init__(self, config, hidden_size=256):
        super(Bert_LSTM, self).__init__(config)
        self.num_labels = config.num_labels
        self.hidden_size = hidden_size
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.bilstm = nn.LSTM(bidirectional=True, num_layers=NUM_LAYER,
                              input_size=config.hidden_size, hidden_size=hidden_size)
        self.classifier = nn.Linear(2 * hidden_size, self.num_labels)
        self.init_weights()

    def init_hidden(self, batch_size):
        # (h_0, c_0) for the bidirectional LSTM: [NUM_LAYER * 2, batch, hidden_size]
        device = next(self.parameters()).device
        return (torch.zeros(NUM_LAYER * 2, batch_size, self.hidden_size, device=device),
                torch.zeros(NUM_LAYER * 2, batch_size, self.hidden_size, device=device))
    
    def forward(self, input_ids, attn_masks, labels=None):
        """
            Parameters:
                input_ids : token ids of shape [Batch x Seq Length]
                attn_masks : attention mask used to ignore padding sub-tokens, shape [B x S]
                labels : label ids of the tokens, shape [B x S]
            Return:
                if labels is None we return (logits,), i.e. the label-wise scores
                if labels is not None we return (loss, logits)
        """
        batch_size = input_ids.size(0)
        enc_out = self.bert(input_ids, attention_mask=attn_masks)
        sequence_output = enc_out[0]                      # [B x S x H]
        sequence_output = self.dropout(sequence_output)
        # BxSxH --> SxBxH (nn.LSTM expects sequence-first by default)
        sequence_output = sequence_output.permute(1, 0, 2).contiguous()
        hidden_state = self.init_hidden(batch_size)
        self.bilstm.flatten_parameters()
        lstm_out, hidden_state = self.bilstm(sequence_output, hidden_state)
        # SxBxD --> BxSxD
        lstm_out = lstm_out.permute(1, 0, 2).contiguous()
        logits = self.classifier(lstm_out)

        outputs = (logits,)
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active (non-padding) positions in the loss
            if attn_masks is not None:
                active_loss = attn_masks.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits
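
For reference, here is a minimal sketch of how this class could be instantiated and called. The checkpoint name, label count, and dummy tensors are illustrative placeholders, not taken from the original training script:

import torch
from transformers import BertConfig

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder checkpoint and label count, for illustration only.
config = BertConfig.from_pretrained("bert-base-cased", num_labels=11)
model = Bert_LSTM.from_pretrained("bert-base-cased", config=config).to(device)

# Dummy batch: 2 sequences of 32 sub-token ids, no padding.
input_ids = torch.randint(0, config.vocab_size, (2, 32), device=device)
attn_masks = torch.ones_like(input_ids)
labels = torch.zeros_like(input_ids)

# With labels: returns (loss, logits); without labels: returns (logits,).
loss, logits = model(input_ids, attn_masks, labels=labels)
with torch.no_grad():
    (logits,) = model(input_ids, attn_masks)
    preds = logits.argmax(dim=-1)   # predicted label ids, shape [B x S]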

from transformers import AdamW, get_linear_schedule_with_warmup

optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)
# defining the scheduler
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=t_total)
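
optimizer_grouped_parameters and t_total come from parts of the training script that are not shown above. For reference, the usual pattern from the transformers NER examples (not necessarily what was used here) builds them as below; train_dataloader, gradient_accumulation_steps, and num_train_epochs are illustrative names:

no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]

# Total number of optimizer steps over the whole run.
t_total = len(train_dataloader) // gradient_accumulation_steps * num_train_epochs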

Epoch: 6
Current LR: 4.0625000000000005e-05
2020-04-22 07:57:26.632254 Step: 1 of 44041 Loss: 0.526913
2020-04-22 08:05:17.593988 Step: 5000 of 44041 Loss: 0.027799
2020-04-22 08:13:09.941345 Step: 10000 of 44041 Loss: 0.030528
2020-04-22 08:21:00.065158 Step: 15000 of 44041 Loss: 0.061899
2020-04-22 08:28:54.857574 Step: 20000 of 44041 Loss: 0.013886
2020-04-22 08:36:48.340908 Step: 25000 of 44041 Loss: 0.012853
2020-04-22 08:44:41.358208 Step: 30000 of 44041 Loss: 0.013785
2020-04-22 08:52:35.497869 Step: 35000 of 44041 Loss: 0.075335
2020-04-22 09:00:30.151845 Step: 40000 of 44041 Loss: 0.217842
Training Loss: 0.053613 for epoch 6
Epoch: 6
Epoch: 4%|▍ | 1/26 [1:21:53<34:07:10, 4913.24s/it]
Validation result:

processed 3220946 tokens with 170733 phrases; found: 172690 phrases; correct: 164607.
accuracy: 97.04%; (non-O)
accuracy: 99.52%; precision: 95.32%; recall: 96.41%; FB1: 95.86%
age: precision: 73.56%; recall: 94.85%; FB1: 82.86% 3480
date: precision: 97.35%; recall: 98.09%; FB1: 97.72% 134238
hospital_name: precision: 63.47%; recall: 78.55%; FB1: 70.21% 1979
id: precision: 94.39%; recall: 95.60%; FB1: 94.99% 2924
location: precision: 59.75%; recall: 88.03%; FB1: 71.19% 1538
org_name: precision: 49.23%; recall: 72.73%; FB1: 58.72% 325
person_name: precision: 93.58%; recall: 90.36%; FB1: 91.94% 27636
room_no: precision: 72.58%; recall: 97.89%; FB1: 83.36% 383
season: precision: 88.00%; recall: 85.44%; FB1: 86.70% 100
telephone_no: precision: 56.32%; recall: 83.05%; FB1: 67.12% 87

Epoch: 7
Current LR: 3.299549709361731e-05
2020-04-22 09:19:19.862900 Step: 1 of 44041 Loss: 0.002444
2020-04-22 09:27:15.515516 Step: 5000 of 44041 Loss: 0.013085
2020-04-22 09:35:08.498298 Step: 10000 of 44041 Loss: 0.011330
2020-04-22 09:43:04.788642 Step: 15000 of 44041 Loss: 0.031396
2020-04-22 09:51:02.563011 Step: 20000 of 44041 Loss: 0.012581
2020-04-22 09:59:03.283105 Step: 25000 of 44041 Loss: 0.011349
2020-04-22 10:07:01.185171 Step: 30000 of 44041 Loss: 0.012138
2020-04-22 10:14:59.467057 Step: 35000 of 44041 Loss: 0.011784
2020-04-22 10:22:52.307055 Step: 40000 of 44041 Loss: 0.172492
Training Loss: 0.106804 for epoch 7
Epoch: 7
Epoch: 8%|▊ | 2/26 [2:44:14<32:48:40, 4921.69s/it]
Validation result:

processed 3220946 tokens with 170733 phrases; found: 0 phrases; correct: 0.
accuracy: 0.00%; (non-O)
accuracy: 91.85%; precision: 0.00%; recall: 0.00%; FB1: 0.00%
age: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
date: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
hospital_name: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
id: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
location: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
org_name: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
person_name: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
room_no: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
season: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
telephone_no: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0

This is perhaps due to exploding gradients, which commonly happen in LSTMs. Do you use gradient clipping?

Could you please share a code snippet showing how to use gradient clipping? As far as I know, the exploding/vanishing gradient problem arises only in vanilla RNNs. Does it happen to LSTMs too?

torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

You can clip gradients like that. AFAIK exploding gradients do happen in LSTMs, though I'm not an expert on them. I'd suggest trying clipping anyway and seeing if it works.
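
For completeness, clip_grad_norm_ is called after loss.backward() and before optimizer.step(). A minimal sketch of that placement follows; train_loader, max_grad_norm, and the batch unpacking are illustrative, not taken from the original training loop:

max_grad_norm = 1.0   # illustrative value; args.clip in the snippet above plays the same role

model.train()
for batch in train_loader:
    input_ids, attn_masks, labels = (t.to(device) for t in batch)

    optimizer.zero_grad()
    loss, logits = model(input_ids, attn_masks, labels=labels)
    loss.backward()

    # Rescale all gradients so their global L2 norm is at most max_grad_norm.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    optimizer.step()
    scheduler.step()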
