Hello, I’m trying to create a custom Bert model by adding a bilstm layer on top of a pretrained Bert model. I have the below code snippet that is causing problems :
class BertClassifier(nn.Module):
"""Bert Model for Classification Tasks."""
def __init__(self, freeze_bert=False):
super(BertClassifier, self).__init__()
# Specify hidden size of BERT, hidden size of our classifier, and number of labels
D_in, H, D_out = 768, 50, 2
# Instantiate BERT model
self.bert = BertModel.from_pretrained('bert-base-multilingual-uncased')
self.lstm = nn.LSTM(D_in, H, batch_first=True, bidirectional=True)
self.linear = nn.Linear(H*2 , D_out)
# Freeze the BERT model
if freeze_bert:
for param in self.bert.parameters():
param.requires_grad = False
def forward(self, input_ids, attention_mask):
# Feed input to BERT
outputs = self.bert(input_ids=input_ids,attention_mask=attention_mask)
sequence_output = outputs[0]
print("sequence_output size", sequence_output.size())
sequence_output, _ = self.lstm(sequence_output)
print("lstm size", sequence_output.size())
linear_output = self.linear(sequence_output)
print("linear_output size", linear_output.size())
return linear_output
def train(model, train_dataloader, val_dataloader=None, epochs=4, evaluation=False):
"""Train the BertClassifier model."""
# Start training loop
print("Start training...\n")
for epoch_i in range(epochs):
# Print the header of the result table
print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}")
print("-" * 70)
# Measure the elapsed time of each epoch
t0_epoch, t0_batch = time.time(), time.time()
# Reset tracking variables at the beginning of each epoch
total_loss, batch_loss, batch_counts = 0, 0, 0
# Put the model into the training mode
model.train()
# For each batch of training data...
for step, batch in enumerate(train_dataloader):
batch_counts += 1
# Load batch to GPU
b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)
# Zero out any previously calculated gradients
model.zero_grad()
# Perform a forward pass. This will return logits.
logits = model(b_input_ids, b_attn_mask)
# Compute loss and accumulate the loss values
loss = loss_fn(logits, b_labels)
batch_loss += loss.item()
total_loss += loss.item()
# Perform a backward pass to calculate gradients
loss.backward()
# Clip the norm of the gradients to 1.0 to prevent "exploding gradients"
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# Update parameters and the learning rate
optimizer.step()
scheduler.step()
# Print the loss values and time elapsed for every 20 batches
if (step % 20 == 0 and step != 0) or (step == len(train_dataloader) - 1):
# Calculate time elapsed for 20 batches
time_elapsed = time.time() - t0_batch
# Print training results
print(
f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^10} | {'-':^9} | {time_elapsed:^9.2f}")
# Reset batch tracking variables
batch_loss, batch_counts = 0, 0
t0_batch = time.time()
# Calculate the average loss over the entire training data
avg_train_loss = total_loss / len(train_dataloader)
print("-" * 70)
#Evaluation
if evaluation == True:
# After the completion of each training epoch, measure the model's performance
# on our validation set.
val_loss, val_accuracy = evaluate(model, val_dataloader)
# Print performance over the entire training data
time_elapsed = time.time() - t0_epoch
print(
f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_accuracy:^9.2f} | {time_elapsed:^9.2f}")
print("-" * 70)
print("\n")
print("Training complete!")
So when i run the code, it returns the error Expected target size (32, 2), got torch.Size([32])
on line loss = loss_fn(logits, b_labels)
( in the training function). These are the sizes :
sequence_output size torch.Size([32, 64, 768])
lstm size torch.Size([32, 64, 100])
linear_output size torch.Size([32, 64, 2])
I tried reshaping the linear output with linear_output = linear_output.view(batch_size, 2)
but after that it threw the error shape '[32, 2]' is invalid for input of size 4096
on line linear_output = linear_output.view(batch_size, 2)
Any help or advice will be much appeciated.