I’m getting the following error when using BERT with a BiLSTM head (my batch size for BERT is 26). I want to concatenate the last 4 hidden layers of BERT and feed the result to a BiLSTM. Here is my model:
from transformers import BertPreTrainedModel, BertModel
import torch
import torch.nn as nn
import torch.nn.functional as F


class BERT(BertPreTrainedModel):
    def __init__(self, config):
        super(BERT, self).__init__(config)
        self.device = config.device
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(0.1)
        # BiLSTM over the concatenation of the last 4 hidden layers,
        # so the input size is 4 * hidden_size
        self.lstm = nn.LSTM(input_size=config.hidden_size * 4, hidden_size=500,
                            num_layers=3, dropout=0.5, bidirectional=True)
        self.qa_outputs = nn.Linear(500 * 2, config.num_labels)
        self.weight_class = config.weight_class
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        # BERT is frozen; only the BiLSTM and classifier are trained here
        with torch.no_grad():
            outputs = self.bert(input_ids, attention_mask=attention_mask,
                                token_type_ids=token_type_ids)
        # outputs[2] holds all hidden states (requires output_hidden_states=True
        # in the config). Concatenate the [CLS] vectors of the last 4 layers:
        # shape (batch_size, 4 * hidden_size)
        cls_output = torch.cat((outputs[2][-4][:, 0, ...], outputs[2][-3][:, 0, ...],
                                outputs[2][-2][:, 0, ...], outputs[2][-1][:, 0, ...]), -1)
        cls_output = self.lstm(cls_output.unsqueeze(0))[0]
        logits = self.qa_outputs(cls_output)
        return logits

    def loss(self, input_ids, attention_mask, token_type_ids, label):
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        cls_output = torch.cat((outputs[2][-4][:, 0, ...], outputs[2][-3][:, 0, ...],
                                outputs[2][-2][:, 0, ...], outputs[2][-1][:, 0, ...]), -1)
        cls_output = self.lstm(cls_output.unsqueeze(0))[0]
        logits = self.qa_outputs(cls_output)
        target = label
        loss = F.cross_entropy(logits, target)
        predict_value = torch.max(logits, 1)[1]
        list_predict = predict_value.cpu().numpy().tolist()
        list_target = target.cpu().numpy().tolist()
        return loss, list_predict, list_target
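For context, this is roughly how I build and call the model. The checkpoint name, num_labels value, and the dummy batch are illustrative, and my config also carries the custom device and weight_class fields the constructor reads:

from transformers import BertConfig, BertTokenizer

config = BertConfig.from_pretrained(
    "bert-base-uncased",
    output_hidden_states=True,  # needed so outputs[2] contains all hidden states
    num_labels=2,
)
config.device = "cpu"       # custom fields my constructor expects
config.weight_class = None

model = BERT.from_pretrained("bert-base-uncased", config=config)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer(["an example sentence"] * 26, padding=True, return_tensors="pt")
logits = model(batch["input_ids"], attention_mask=batch["attention_mask"])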
I really don’t know how to debug this. Does anyone have a solution for this error? Thanks in advance.