Out of Memory after 1000 training + 500 validation steps

I am fine-tuning a pre-trained Hugging Face transformer on an NLP objective. My training loop is defined as follows:

import torch
from tqdm import tqdm

for epoch in range(nepochs):
    train_loss = 0.0
    with tqdm(trainloader, unit="batch") as tsteps:
        for step, inputs in enumerate(tsteps):
            tsteps.set_description(f"train step {(step + 1)}")

            encoder_input_ids, decoder_input_ids, labels = inputs[0].to(Config.device), inputs[1].to(Config.device), inputs[2].to(Config.device)

            # train_step is defined elsewhere (not shown here)
            train_loss += train_step(encoder_input_ids, decoder_input_ids, labels)

            tsteps.set_postfix(loss=train_loss / (step + 1))

            # run validation every 1000 training steps
            if (step + 1) % 1000 == 0:
                validate(epoch)
                torch.cuda.empty_cache()  # release cached GPU blocks once validation is done

As seen above, my training and validation steps are in separate functions (train_step itself is not shown). The validation code is:

import torch

def val_step(encoder_input_ids, decoder_input_ids, labels):
    # request only the logits: no attentions, no hidden states, no KV cache
    logits = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, output_attentions=False, output_hidden_states=False, use_cache=False).logits
    loss = loss_fn(labels, logits)

    return loss.item(), logits

def validate(epoch):
    model.eval()  # disable dropout while validating
    with torch.no_grad():  # no autograd graph is built during validation
        val_loss = 0.0
        for step, inputs in enumerate(testloader):
            encoder_input_ids, decoder_input_ids, labels = inputs[0].to(Config.device), inputs[1].to(Config.device), inputs[2].to(Config.device)

            loss, logits = val_step(encoder_input_ids, decoder_input_ids, labels)
            val_loss += loss

            # drop references to this batch's device tensors before the next batch is loaded
            del logits, encoder_input_ids, decoder_input_ids, labels

        print(f'[epoch: {epoch + 1}, step: {step + 1:5d}] val_loss: {val_loss / (step + 1):.3f}')
    model.train()  # switch back to training mode before training resumes
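The minimal sketch below shows what the `torch.no_grad()` block in `validate` changes. It uses a toy `nn.Linear` layer on the CPU as a stand-in for the transformer, an assumption made only for illustration. An output computed outside the context records an autograd graph (`grad_fn`), and that graph keeps the tensors needed for a backward pass alive. The same call inside `torch.no_grad()` records nothing, so nothing extra is retained.

```python
import torch
import torch.nn as nn

# Toy stand-in for the real model (illustrative assumption, not the actual transformer)
layer = nn.Linear(4, 4)
x = torch.randn(2, 4)

# Outside no_grad: the output records a graph (grad_fn) for a later backward pass
y_train = layer(x)

# Inside no_grad: no graph is recorded, so no tensors are saved for backward
with torch.no_grad():
    y_eval = layer(x)

print(y_train.grad_fn is not None, y_train.requires_grad)  # True True
print(y_eval.grad_fn is None, y_eval.requires_grad)        # True False
```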


I have made sure that the model does not return extra outputs such as hidden states or attentions. I also delete the logits and inputs after each step, and I call torch.cuda.empty_cache() after the entire validation loop.
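You can observe the difference between `del` and `torch.cuda.empty_cache()` with PyTorch's caching-allocator counters. This is a sketch. The helper name `report` and the 256 MiB tensor size are arbitrary choices for illustration, and the numbers only mean something on a CUDA device. `torch.cuda.memory_allocated()` counts the bytes occupied by live tensors. `torch.cuda.memory_reserved()` also includes freed blocks that the allocator keeps cached for reuse. `empty_cache()` can only return those cached blocks; it cannot free memory that a tensor still references.

```python
import torch

def report(tag: str) -> tuple[int, int]:
    # allocated = bytes held by live tensors; reserved = allocated + cached free blocks
    alloc = torch.cuda.memory_allocated()
    reserved = torch.cuda.memory_reserved()
    print(f"{tag:>12}: allocated={alloc / 2**20:8.1f} MiB  reserved={reserved / 2**20:8.1f} MiB")
    return alloc, reserved

if torch.cuda.is_available():
    x = torch.empty(64 * 2**20, dtype=torch.float32, device="cuda")  # 256 MiB (arbitrary size)
    report("after alloc")
    del x                      # tensor freed; its block stays cached by the allocator
    report("after del")
    torch.cuda.empty_cache()   # cached, unused blocks are released back to the driver
    report("after empty")
else:
    print("CUDA not available; allocator counters are not meaningful on CPU")
```

On a GPU, `del` should lower the allocated figure while the reserved figure stays the same. Only `empty_cache()` should lower the reserved figure.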

I have an A100 GPU with 40 GB of memory. The model trains fine for the first 1000 steps and then runs validation, which also completes. However, as soon as training resumes, it throws an out-of-memory error. I am not sure how to get around this.

train step 1:   0%|          | 0/15625 [00:02<?, ?batch/s]/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:854: FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.
train step 2:   0%|          | 1/15625 [07:20<1913:23:36, 440.87s/batch, loss=0.478][epoch: 1, step:   472] val_loss: 0.408
train step 2:   0%|          | 1/15625 [07:21<1917:33:55, 441.84s/batch, loss=0.478]
OutOfMemoryError                          Traceback (most recent call last)

I have read discussions on this forum about similar scenarios. I understand that tensors which are no longer referenced are freed (for example, when their names are rebound in the next loop iteration), which is consistent with the 1000 successful training steps. Since validation also completes, memory does not seem to accumulate across validation steps either, i.e., the memory used for training is being released. However, I am not sure what goes wrong when training resumes after validation.
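It helps to be precise about when Python actually frees these objects. A `for` loop does not drop its variables when it ends. The loop variable, and anything assigned inside the loop body, stays bound in the enclosing scope after the last iteration. So the last batch's objects stay alive until the name is deleted or the function returns. Objects from earlier iterations are freed as soon as their names are rebound. In the training loop above, for example, the last training batch's tensors are still referenced while validation runs. The sketch below shows this behaviour (in CPython, which frees objects by reference counting) using a plain class as a stand-in for a GPU tensor, an illustrative assumption. It uses `weakref` to observe when each object is freed.

```python
import weakref

class Batch:
    """Stand-in for a GPU tensor; purely illustrative."""
    def __init__(self, i: int) -> None:
        self.i = i

refs = []  # weak references observe objects without keeping them alive

def run() -> list:
    for i in range(3):
        batch = Batch(i)              # rebinding `batch` releases the previous object
        refs.append(weakref.ref(batch))
    # After the loop, `batch` still refers to the last object
    alive_after_loop = [r() is not None for r in refs]
    del batch                         # only now is the last object released
    alive_after_del = [r() is not None for r in refs]
    return [alive_after_loop, alive_after_del]

after_loop, after_del = run()
print(after_loop)  # [False, False, True]
print(after_del)   # [False, False, False]
```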
If I run only the validation loop, it uses around 17 GB of GPU memory in total. I would like that memory to be released once validation is done. Is there a way to do this?
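Resetting the allocator's peak counter around the call shows how much of that memory validation holds only transiently and how much it still holds afterwards. This is a minimal sketch assuming a CUDA device. `measure_peak` is a hypothetical helper name, and `fn` stands in for a call such as `validate(epoch)`.

```python
import torch

def measure_peak(fn, *args, **kwargs):
    """Run fn and report peak and residual CUDA memory (hypothetical helper)."""
    if not torch.cuda.is_available():
        return fn(*args, **kwargs), None
    torch.cuda.reset_peak_memory_stats()      # peak counter restarts from current usage
    before = torch.cuda.memory_allocated()
    result = fn(*args, **kwargs)
    stats = {
        "before_MiB": before / 2**20,
        "peak_MiB": torch.cuda.max_memory_allocated() / 2**20,  # highest allocation during fn
        "after_MiB": torch.cuda.memory_allocated() / 2**20,     # held by live tensors after fn
        "reserved_MiB": torch.cuda.memory_reserved() / 2**20,   # includes cached free blocks
    }
    print({k: round(v, 1) for k, v in stats.items()})
    return result, stats
```

For example, calling `measure_peak(validate, epoch)` in the validation branch of the training loop would print these figures for each validation pass. A large gap between `peak_MiB` and `after_MiB` means most of the validation memory was transient. `reserved_MiB` stays high until `torch.cuda.empty_cache()` is called.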