I have a very big data. I am using the following code to train a model. The dataset is iterable. Some of the batches I get different size. Is there any solution to avoid different batch size?
max_length = 512
def tokenize_and_align_labels(examples):
tokenized_inputs = tokenizer(examples["tokens"], padding= 'max_length', truncation=True, max_length=max_length, is_split_into_words=True )
tokenized_inputs["labels"] = examples["ner_tags"]
return tokenized_inputs
traindataset = dataset_iter.map(tokenize_and_align_labels,batched=True, remove_columns=["split","tokens", "ner_tags","id"])
evaldataset = evaldataset_iter.map(tokenize_and_align_labels,batched=True, remove_columns=["split","tokens", "ner_tags","id"])
metric = evaluate.load("seqeval")
prevScore = 100
since = time.time()
for epoch in range(num_epochs):
model.train().to(device)
for i, batch in enumerate(tqdm(train_dataloader)):
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(**batch)
loss = outputs[0]
loss.backward()
optimizer.step()
optimizer.zero_grad()
if i % 10 == 0:
print(f"loss: {loss}")
model.eval().to(device)
for j, batch in enumerate(tqdm(eval_dataloader)):
batch = {k: v.to(device) for k, v in batch.items()}
with torch.no_grad():
outputs = model(**batch)
logits = outputs.logits
predictions = torch.argmax(logits, dim=-1)
predictions = predictions.to(torch.int).tolist()[0]
labels = batch['labels'].to(torch.int).tolist()[0]
true_predictions = [
[id2label[prediction] for prediction, label in zip(predictions, labels) if label != -100]
]
true_labels = [
[id2label[label] for prediction, label in zip(predictions, labels) if label != -100]
]
metric.add_batch(predictions=true_predictions, references=true_labels)
loss = outputs[0]
results = metric.compute(predictions=true_predictions, references=true_labels)
if(prevScore<results['skill']['precision']):
prevScore=results['precision']
highestMetrics=results
torch.save(model, 'iq4roberta.pt')
else:
orig_stdout = sys.stdout
print("Highest Till now is",file = orig_stdout)
print(results)
time_elapsed = time.time() - since
print("Time",time_elapsed)