Getting CUDA out of memory while running a Longformer model in Google Colab. Similar code using BERT works fine

I am using the base Longformer model with no additional layers, for a text classification task. With 100 rows I can train the model, but if I increase the size of the dataset, it fails with CUDA out of memory. I was previously using the Adam optimizer and have switched to SGD, and my batch size is already 1. Please help. How should I train my model? My model :slight_smile:

```python
import torch
from transformers import LongformerForSequenceClassification


class LongformerForMultiSequenceClassification(LongformerForSequenceClassification):

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        global_attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
    ):
        if global_attention_mask is None:
            global_attention_mask = torch.zeros_like(input_ids)
            # global attention on the CLS token
            global_attention_mask[:, 0] = 1

        outputs = self.longformer(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
        )
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)
        outputs = (logits,) + outputs[2:]

        if labels is not None:
            # CrossEntropyLoss expects raw logits; applying sigmoid first is incorrect
            criterion = torch.nn.CrossEntropyLoss()
            loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)
```
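
For reference, a minimal way to instantiate this class (the checkpoint name and `num_labels=2` here are assumptions; substitute your own):

```python
from transformers import LongformerTokenizerFast

# hypothetical setup: base Longformer checkpoint, binary classification
tokenizer = LongformerTokenizerFast.from_pretrained('allenai/longformer-base-4096')
model = LongformerForMultiSequenceClassification.from_pretrained(
    'allenai/longformer-base-4096', num_labels=2
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
```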

Optimizer:

```python
import torch.optim as optim
from transformers import get_linear_schedule_with_warmup

# switched from AdamW to SGD (SGD without momentum keeps no optimizer state in memory)
optimizer = optim.SGD(model.parameters(), lr=0.001)

epochs = 5

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=len(dataloader_train) * epochs,
)
```

Training:

```python
import torch
from tqdm import tqdm  # or: from tqdm.notebook import tqdm in Colab

for epoch in tqdm(range(1, epochs + 1)):

    model.train()

    loss_train_total = 0

    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch),
                        leave=False, disable=False)

    for batch in progress_bar:
        # move the batch tensors to the GPU
        batch = tuple(b.to(device) for b in batch)

        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1],
                  'labels':         batch[2],
                  }

        # forward pass; outputs[0] is the loss because labels were passed
        outputs = model(**inputs)
        loss = outputs[0]
        loss_train_total += loss.item()

        # compute the gradients
        loss.backward()

        # clip gradients to prevent gradient explosion
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # update the weights, then the learning rate
        optimizer.step()
        scheduler.step()

        # empty the gradients from the previous iteration
        model.zero_grad()

        # len(batch) is the number of tensors (3), not the batch size, so report the raw loss
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item())})

    # torch.save(model.state_dict(), f'/content/Gdrive/My Drive/finetuned_longformer_epoch_{epoch}.model')
    # torch.save(model.state_dict(), f'checkpoint{epoch}.pth')

    tqdm.write(f'\nEpoch {epoch}')

    loss_train_avg = loss_train_total / len(dataloader_train)
    tqdm.write(f'Training loss: {loss_train_avg}')

    val_loss, predictions, true_vals = evaluate(dataloader_validation)
    val_f1 = f1_score_func(predictions, true_vals)
    tqdm.write(f'Validation loss: {val_loss}')
    tqdm.write(f'F1 Score (Weighted): {val_f1}')
```
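
(`evaluate` and `f1_score_func` are not defined in the post; a sketch of what they presumably look like, assuming the same three-tensor batches and a weighted F1, would be:)

```python
import numpy as np
from sklearn.metrics import f1_score


def evaluate(dataloader_val):
    model.eval()
    loss_val_total = 0
    predictions, true_vals = [], []

    for batch in dataloader_val:
        batch = tuple(b.to(device) for b in batch)
        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1],
                  'labels':         batch[2],
                  }
        # no gradients needed during evaluation
        with torch.no_grad():
            outputs = model(**inputs)

        loss_val_total += outputs[0].item()
        predictions.append(outputs[1].detach().cpu().numpy())
        true_vals.append(inputs['labels'].cpu().numpy())

    loss_val_avg = loss_val_total / len(dataloader_val)
    predictions = np.concatenate(predictions, axis=0)
    true_vals = np.concatenate(true_vals, axis=0)
    return loss_val_avg, predictions, true_vals


def f1_score_func(preds, labels):
    preds_flat = np.argmax(preds, axis=1).flatten()
    return f1_score(labels.flatten(), preds_flat, average='weighted')
```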

If your GPU is raising an OOM error in the first training iteration with a batch size of 1, you could try to apply torch.utils.checkpoint to trade compute for memory, or reduce the model size or the input length (if possible).
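
For example, Hugging Face models expose activation checkpointing directly (a sketch; the exact method depends on your transformers version):

```python
# newer transformers releases: enable activation checkpointing on the loaded model
model.gradient_checkpointing_enable()

# older releases instead take a config flag at load time:
# model = LongformerForMultiSequenceClassification.from_pretrained(
#     'allenai/longformer-base-4096', gradient_checkpointing=True, num_labels=2
# )

# reducing the input length also helps, since memory grows with sequence length;
# e.g. truncate to 1024 tokens instead of Longformer's full 4096
# (texts here stands in for your raw input strings):
encodings = tokenizer(texts, truncation=True, padding='max_length',
                      max_length=1024, return_tensors='pt')
```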

I cannot see any obvious errors in your code.
Note that you can add code snippets by wrapping them in three backticks (```), which makes debugging easier. :wink: