Getting torch.cuda.OutOfMemoryError with batch size = 1

Hi! I've recently been getting torch.cuda.OutOfMemoryError even though my batch size is only 1. The GPU is a GeForce RTX 3090 with roughly 24 GB of memory, and the model is a simple pretrained BERT. My task is hierarchical text classification. Here is my code:

    # setup (imports, tokenizer, taxonomy, dataset, args and device are defined elsewhere)
    LM = LM_model.from_pretrained('bert-base-uncased').to(device)

    optimizer_lm = AdamW(LM.parameters(), lr = args.lr_lm)

    loss_fn = nn.BCEWithLogitsLoss()

    for epoch in range(args.max_epochs):
        torch.cuda.empty_cache()
        ################# start training ##################

        LM.train()
        idx = 0
        loss_val = 0.

        for data in tqdm(dataset['train']):  # one text at a time
            # data: [text (str), level-0 label id (int), level-1 label id (int), ...]
            level = 0
            root_names = taxonomy['l0_list']  # level-0 label names

            if idx % args.batch_size == 0:
                loss = torch.tensor(0.).to(device)  # reset the accumulated loss at the start of a window

            while level < args.depth:
                cur_pred_score = 0.  # best score among this level's candidate labels
                cur_pred_name = ""   # name of the best-scoring candidate
                cur_loss = torch.tensor(0.).to(device)

                for root_name in root_names:
                    # tokenize the (text, candidate label name) pair and move it to the GPU
                    input = {}
                    for k, v in tokenizer(data[0], root_name, truncation="only_first").items():
                        input[k] = torch.tensor(v).to(device).unsqueeze(dim=0)

                    pred = LM(input=input)  # single logit for this (text, candidate label) pair

                    # drop references to the input tensors
                    for k in ["input_ids", "token_type_ids", "attention_mask"]:
                        del input[k]

                    if pred.item() > cur_pred_score:
                        cur_pred_score = pred.item()
                        cur_pred_name = root_name
                    if data[level + 1] == taxonomy[f'l{level}_label2id'][root_name]:  # true label at this level
                        cur_loss += loss_fn(pred, torch.tensor([[1.]]).to(device))  # pred is e.g. tensor([[0.4912]], device='cuda:...')
                    else:  # false label
                        cur_loss += loss_fn(pred, torch.tensor([[0.]]).to(device))

                    del pred 
             
                if taxonomy[f'l{level}_label2id'][cur_pred_name] != data[level+1]:  # wrong prediction at this level
                    loss += cur_loss * (args.depth - level)
                    del cur_loss
                    break
                else:  # correct prediction at this level
                    loss += cur_loss
                    del cur_loss
                    if level < args.depth - 1:  # not a leaf label: descend to the children of the true label
                        root_names = []
                        root_id = data[level+1]
                        ch_id_list = taxonomy[f'l{level}_to_l{level+1}'][root_id]
                        for ch_id in ch_id_list:
                            root_names.append(taxonomy[f'l{level+1}_id2label'][ch_id])
            
                level += 1

            
            if (idx + 1) % args.batch_size == 0:  # end of an accumulation window: backprop and step
                loss_val += loss.item()
                optimizer_lm.zero_grad()
                loss.backward()
                optimizer_lm.step()
                #del loss
                torch.cuda.empty_cache() 

            idx += 1
            
        ################ training end ########################
        
        torch.cuda.empty_cache()
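
For context, LM_model is essentially a BERT encoder with a single-logit head that scores one (text, label name) pair at a time; roughly like the sketch below (simplified, the real class has a few more options):

    # Rough sketch of LM_model (simplified; details differ from my real class)
    import torch.nn as nn
    from transformers import BertModel

    class LM_model(nn.Module):
        def __init__(self, bert):
            super().__init__()
            self.bert = bert  # pretrained BERT encoder
            self.scorer = nn.Linear(bert.config.hidden_size, 1)  # one logit per (text, label) pair

        @classmethod
        def from_pretrained(cls, name):
            return cls(BertModel.from_pretrained(name))

        def forward(self, input):
            out = self.bert(**input)  # input holds input_ids, token_type_ids, attention_mask
            cls_vec = out.last_hidden_state[:, 0]  # [CLS] representation
            return self.scorer(cls_vec)  # shape (1, 1) logit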
      

Since for data in tqdm(dataset['train']) yields one text at a time, I assume my batch size is effectively 1. I use args.batch_size and idx to decide when to call backward and step the optimizer (there is a simplified sketch of this accumulation pattern after the traceback below). When I set args.batch_size = 5, I get this:

      File "root/.conda/envs/hvb/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 116, in forward
        return F.linear(input, self.weight, self.bias)
    torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 2 has a total capacity of 23.69 GiB of which 7.94 MiB is free. Including non-PyTorch memory, this process has 23.67 GiB memory in use. Of the allocated memory 23.31 GiB is allocated by PyTorch, and 53.07 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
    Uncaught exception. Entering post mortem debugging
    Running 'cont' or 'step' will restart the program
    > root/.conda/envs/hvb/lib/python3.8/site-packages/torch/nn/modules/linear.py(116)forward()
    -> return F.linear(input, self.weight, self.bias)
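
To show the accumulation logic in isolation (with the taxonomy loop stripped out), this is the pattern I think I'm following; per_sample_loss here is just a stand-in for the while-loop above:

    # Simplified view of my accumulation scheme (per_sample_loss stands in for the taxonomy loop)
    for idx, data in enumerate(dataset['train']):
        if idx % args.batch_size == 0:
            loss = torch.tensor(0.).to(device)  # reset the accumulator at the start of a window

        loss += per_sample_loss(data)  # every sample adds its own computation graph to loss

        if (idx + 1) % args.batch_size == 0:
            optimizer_lm.zero_grad()
            loss.backward()  # all graphs built since the last backward stay in memory until this call
            optimizer_lm.step()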

After I set args.batch_size = 2 the problem goes away, but I'm still not sure what is actually happening on the GPU: as far as I can tell I've freed every redundant tensor, my model is not large, and my GPU has plenty of capacity. Any help would be appreciated!
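
If readings from the built-in memory counters would help, I can post those too; between iterations I can print something like this:

    # Built-in PyTorch counters for watching GPU memory between iterations
    allocated = torch.cuda.memory_allocated(device) / 1024**3  # GiB held by live tensors
    reserved = torch.cuda.memory_reserved(device) / 1024**3    # GiB held by the caching allocator
    print(f"allocated {allocated:.2f} GiB, reserved {reserved:.2f} GiB")
    # torch.cuda.memory_summary(device) gives a fuller breakdown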