CUDA running out of memory after a few batches in an epoch

This is my training function:

import torch
from tqdm import tqdm

def train():
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    model.train()
    optim = torch.optim.AdamW(model.parameters(), lr=5e-5)

    for epoch in range(10):
        with tqdm(dataloader, unit=" batch", leave=True, position=0) as tepoch:
            for i, data in enumerate(tepoch):
                inputs = tokenizer(data['text'], padding=True, truncation=True, return_tensors='pt')
                lm_labels = inputs['input_ids'].to(device)
                sarcasm_labels = data['label'].to(device)
                masked_tokens = mask_inputs(inputs['input_ids']).to(device)
                attention_mask = inputs['attention_mask'].to(device)

                outputs = model(masked_tokens, attention_mask=attention_mask, labels=lm_labels, next_sentence_label=sarcasm_labels)

                optim.zero_grad()
                # extract loss
                loss = outputs.loss
                # backpropagate to compute gradients for every parameter that requires them
                loss.backward()
                # update parameters
                optim.step()
                # print relevant info to the progress bar
                tepoch.set_description(f'Epoch {epoch}')
                tepoch.set_postfix(loss=loss.detach().item())
                torch.cuda.empty_cache()

Training begins fine; however, after a few batches in the first epoch itself, I get a RuntimeError saying CUDA is out of memory.
The model is a Hugging Face BertForPreTraining model.
Can someone help me with this?

Edit: For some reason, if I reduce the size of my dataset, the problem goes away, but I cannot figure out why this is the case.

This might point to a memory increase in each iteration: by reducing the dataset size you also reduce the number of iterations per epoch, so the accumulated increase may simply no longer be large enough to trigger the OOM.
Check the memory usage in your code, e.g. via torch.cuda.memory_summary() or torch.cuda.memory_allocated(), inside the training iterations and try to narrow down where the increase happens. You should also see that e.g. loss.backward() reduces the memory usage, since the intermediate activations stored for the backward pass are freed there.
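Here is a minimal sketch of that kind of logging. The model, layer sizes, and batch size below are dummy stand-ins (not your BertForPreTraining setup) so the snippet runs on its own; in your loop you would call the same helper after the forward pass, after loss.backward(), and after optim.step():

import torch

def log_mem(tag):
    # Print currently allocated and reserved CUDA memory in MB.
    alloc = torch.cuda.memory_allocated() / 1024**2
    reserved = torch.cuda.memory_reserved() / 1024**2
    print(f'{tag}: allocated={alloc:.1f} MB, reserved={reserved:.1f} MB')

assert torch.cuda.is_available()
device = torch.device('cuda')

# Dummy stand-in for the real model; the placement of the logging calls is what matters.
layers = []
for _ in range(4):
    layers += [torch.nn.Linear(1024, 1024), torch.nn.ReLU()]
model = torch.nn.Sequential(*layers).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=5e-5)

for i in range(3):
    x = torch.randn(64, 1024, device=device)
    out = model(x)
    loss = out.mean()
    log_mem(f'batch {i} after forward')   # activations kept for backward are still held here
    optim.zero_grad()
    loss.backward()
    log_mem(f'batch {i} after backward')  # should drop once those activations are freed
    optim.step()
    log_mem(f'batch {i} after step')      # AdamW state buffers are allocated on the first step
    del out, loss                         # dangling references to GPU tensors keep their memory alive

If the allocated memory keeps climbing from batch to batch in your real loop, a common cause is a Python reference that keeps the computation graph alive (e.g. storing loss instead of loss.item()); torch.cuda.memory_summary() gives a more detailed breakdown if the simple counters are not enough.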