Environment: torch 1.1.0
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertPreTrainedModel, BertModel
#from torchcrf import CRF
#from torch.nn import log_softmax
NUM_LAYER = 3
class Bert_LSTM(BertPreTrainedModel):
    """
    Defines the network. We inherit from BertPreTrainedModel and stack a
    BiLSTM plus a linear classifier on top of the BERT encoder (the CRF
    output layer is commented out above).
    """
    def __init__(self, config, hidden_size=256):
        super(Bert_LSTM, self).__init__(config)
        self.num_labels = config.num_labels
        self.hidden_size = hidden_size
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # BiLSTM over the BERT token embeddings; input_size must match BERT's hidden size
        self.bilstm = nn.LSTM(input_size=config.hidden_size, hidden_size=hidden_size,
                              num_layers=NUM_LAYER, bidirectional=True)
        # 2 * hidden_size because the BiLSTM concatenates both directions
        self.classifier = nn.Linear(2 * hidden_size, self.num_labels)
        self.init_weights()

    def init_hidden(self, batch_size):
        # zeroed (h_0, c_0) states, created on the same device as the model parameters
        weight = next(self.parameters())
        h_0 = weight.new_zeros(NUM_LAYER * 2, batch_size, self.hidden_size)
        c_0 = weight.new_zeros(NUM_LAYER * 2, batch_size, self.hidden_size)
        return h_0, c_0

    def forward(self, input_ids, attn_masks, labels=None):
        """
        Parameters:
            input_ids : token ids of shape [Batch x Seq Length]
            attn_masks : attention mask used to ignore the padding sub-tokens, shape [B x S]
            labels : label ids of the tokens, shape [B x S]
        Return:
            if labels is None, the predicted label indices and the label-wise scores;
            if labels is not None, the loss and the logits
        """
        batch_size = input_ids.size(0)
        #with torch.no_grad():
        enc_out = self.bert(input_ids, attention_mask=attn_masks)
        sequence_output = enc_out[0]                        # [B x S x D]
        sequence_output = self.dropout(sequence_output)
        # BxSxD --> SxBxD, since nn.LSTM defaults to batch_first=False
        sequence_output = sequence_output.permute(1, 0, 2).contiguous()
        hidden_state = self.init_hidden(batch_size)
        self.bilstm.flatten_parameters()
        lstm_out, hidden_state = self.bilstm(sequence_output, hidden_state)
        # SxBxD --> BxSxD
        lstm_out = lstm_out.permute(1, 0, 2).contiguous()
        logits = self.classifier(lstm_out)                  # [B x S x num_labels]
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep the active (non-padding) parts of the loss
            if attn_masks is not None:
                active_loss = attn_masks.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1),
                    torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss, logits
        # inference path: per-token label-wise scores and the argmax label index
        scores, preds = torch.max(logits, dim=-1)
        return preds, scores
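
A minimal smoke test of both forward paths (the checkpoint name, label count, and batch shapes below are assumptions for illustration, not taken from the original):

# hypothetical setup: bert-base-uncased encoder with 21 NER labels (assumed)
model = Bert_LSTM.from_pretrained("bert-base-uncased", num_labels=21)
model.eval()

input_ids = torch.randint(0, model.config.vocab_size, (2, 16))  # fake batch: 2 sequences of 16 tokens
attn_masks = torch.ones_like(input_ids)
labels = torch.zeros_like(input_ids)                            # dummy labels, just to exercise the loss path

loss, logits = model(input_ids, attn_masks, labels=labels)      # training path
preds, scores = model(input_ids, attn_masks)                    # inference path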
from transformers import AdamW, get_linear_schedule_with_warmup

optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)
# defining the linear warmup/decay scheduler
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=t_total)
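
`optimizer_grouped_parameters` and `t_total` are not defined in the snippet above; a common way to build them looks like this (the decay/no-decay split and the step count are assumptions, not taken from the original):

no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    # assumed convention: weight decay on everything except biases and LayerNorm weights
    {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
t_total = len(train_dataloader) * num_epochs  # total optimizer steps (assumed: no gradient accumulation)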
Epoch: 6
Current LR: 4.0625000000000005e-05
2020-04-22 07:57:26.632254 Step: 1 of 44041 Loss: 0.526913
2020-04-22 08:05:17.593988 Step: 5000 of 44041 Loss: 0.027799
2020-04-22 08:13:09.941345 Step: 10000 of 44041 Loss: 0.030528
2020-04-22 08:21:00.065158 Step: 15000 of 44041 Loss: 0.061899
2020-04-22 08:28:54.857574 Step: 20000 of 44041 Loss: 0.013886
2020-04-22 08:36:48.340908 Step: 25000 of 44041 Loss: 0.012853
2020-04-22 08:44:41.358208 Step: 30000 of 44041 Loss: 0.013785
2020-04-22 08:52:35.497869 Step: 35000 of 44041 Loss: 0.075335
2020-04-22 09:00:30.151845 Step: 40000 of 44041 Loss: 0.217842
Training Loss: 0.053613 for epoch 6
Epoch: 4%|▍ | 1/26 [1:21:53<34:07:10, 4913.24s/it]
Validation result:
processed 3220946 tokens with 170733 phrases; found: 172690 phrases; correct: 164607.
accuracy: 97.04%; (non-O)
accuracy: 99.52%; precision: 95.32%; recall: 96.41%; FB1: 95.86%
age: precision: 73.56%; recall: 94.85%; FB1: 82.86% 3480
date: precision: 97.35%; recall: 98.09%; FB1: 97.72% 134238
hospital_name: precision: 63.47%; recall: 78.55%; FB1: 70.21% 1979
id: precision: 94.39%; recall: 95.60%; FB1: 94.99% 2924
location: precision: 59.75%; recall: 88.03%; FB1: 71.19% 1538
org_name: precision: 49.23%; recall: 72.73%; FB1: 58.72% 325
person_name: precision: 93.58%; recall: 90.36%; FB1: 91.94% 27636
room_no: precision: 72.58%; recall: 97.89%; FB1: 83.36% 383
season: precision: 88.00%; recall: 85.44%; FB1: 86.70% 100
telephone_no: precision: 56.32%; recall: 83.05%; FB1: 67.12% 87
Epoch: 7
Current LR: 3.299549709361731e-05
2020-04-22 09:19:19.862900 Step: 1 of 44041 Loss: 0.002444
2020-04-22 09:27:15.515516 Step: 5000 of 44041 Loss: 0.013085
2020-04-22 09:35:08.498298 Step: 10000 of 44041 Loss: 0.011330
2020-04-22 09:43:04.788642 Step: 15000 of 44041 Loss: 0.031396
2020-04-22 09:51:02.563011 Step: 20000 of 44041 Loss: 0.012581
2020-04-22 09:59:03.283105 Step: 25000 of 44041 Loss: 0.011349
2020-04-22 10:07:01.185171 Step: 30000 of 44041 Loss: 0.012138
2020-04-22 10:14:59.467057 Step: 35000 of 44041 Loss: 0.011784
2020-04-22 10:22:52.307055 Step: 40000 of 44041 Loss: 0.172492
Training Loss: 0.106804 for epoch 7
Epoch: 8%|▊ | 2/26 [2:44:14<32:48:40, 4921.69s/it]
Validation result:
processed 3220946 tokens with 170733 phrases; found: 0 phrases; correct: 0.
accuracy: 0.00%; (non-O)
accuracy: 91.85%; precision: 0.00%; recall: 0.00%; FB1: 0.00%
age: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
date: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
hospital_name: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
id: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
location: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
org_name: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
person_name: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
room_no: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
season: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0
telephone_no: precision: 0.00%; recall: 0.00%; FB1: 0.00% 0