How to accommodate larger batches?

I am trying to train BART-medium, but I can only fit batch_size=1 on Google Colab before running out of GPU memory. How can the following code be modified to train with larger batches? (I've also included a gradient-accumulation attempt after the code, in case that is the right direction.)

for epoch_i in range(0, epochs):

    # ========================================
    #               Training
    # ========================================

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    t0 = time.time()
    total_train_loss = 0
    model.train()

    for step, batch in enumerate(train_dataloader):

        b_input_ids = batch['prompt'].to(device)
        b_labels = batch['label'].to(device)
        b_masks = batch['attention_mask'].to(device)

        model.zero_grad()

        outputs = model(b_input_ids,
                        attention_mask=b_masks,
                        labels=b_labels)

        loss = outputs[0]
        batch_loss = loss.item()
        total_train_loss += batch_loss

        # Get sample every x batches.
        if step % sample_every == 0 and step != 0:

            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}. Loss: {:>5,}.   Elapsed: {:}.'.format(step, len(train_dataloader), batch_loss, elapsed))

            model.eval()
            sample_outputs = model.generate(
                bos_token_id=random.randint(1, 30000),
                do_sample=True,
                top_k=50,
                max_length=200,
                top_p=0.95,
                num_return_sequences=1
            )
            for i, sample_output in enumerate(sample_outputs):
                print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))

            model.train()

        loss.backward()
        optimizer.step()
        scheduler.step()

    # Calculate the average loss over all of the batches.
    avg_train_loss = total_train_loss / len(train_dataloader)

    # Measure how long this epoch took.
    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epoch took: {:}".format(training_time))

    # ========================================
    #               Validation
    # ========================================

    print("")
    print("Running Validation...")

    t0 = time.time()
    model.eval()
    total_eval_loss = 0
    nb_eval_steps = 0

    # Evaluate data for one epoch.
    for batch in validation_dataloader:

        b_input_ids = batch['prompt'].to(device)
        b_labels = batch['label'].to(device)
        b_masks = batch['attention_mask'].to(device)

        with torch.no_grad():
            outputs = model(b_input_ids,
                            attention_mask=b_masks,
                            labels=b_labels)
            loss = outputs[0]

        batch_loss = loss.item()
        total_eval_loss += batch_loss

    avg_val_loss = total_eval_loss / len(validation_dataloader)
    validation_time = format_time(time.time() - t0)

    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    print("  Validation took: {:}".format(validation_time))

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Training Time': training_time,
            'Validation Time': validation_time
        }
    )

print("")
print("Training complete!")
print("Total training took {:} (h:mm:ss)".format(format_time(time.time() - total_t0)))