Model inference consumes memory every time model(input) is called

CPU memory keeps accumulating every time I call the model for inference:

import gc

import torch


def model_inference(
    model,
    device,
    input_ids
):
    # move the input onto the target device
    input_ids = input_ids.to(device)

    # no_grad so no autograd graph is retained for the forward pass
    with torch.no_grad():
        outputs = model(input_ids=input_ids, output_hidden_states=True)

    # attempts to release memory after each call
    torch.cuda.empty_cache()
    gc.collect()

    # print("+" * 5, psutil.virtual_memory().available / (1024 ** 3))

    return outputs
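
For context, this is roughly how the function gets called; the loop and variable names below are illustrative rather than my exact code, but I do keep the returned outputs around so I can read the hidden states afterwards:

results = []
for input_ids in dataset:  # one tokenized sample at a time
    outputs = model_inference(model, device, input_ids)
    # holding on to part of the returned output object between iterations
    results.append(outputs.hidden_states[-1])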
    
The memory profiler output for this function is below. You can see that at line 138 the program is already using 5967 MiB of memory, and that it adds ~317 MiB when lines 140-141 execute. That memory is never freed: it keeps accumulating on every call and is what eventually exhausts the available memory.
    Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   132   5967.3 MiB   5967.3 MiB           1   @profile
   133                                         def model_inference(
   134                                             model,
   135                                             device,
   136                                             input_ids
   137                                         ):
   138   5967.3 MiB      0.0 MiB           1       input_ids = input_ids.to(device)
   139                                         
   140   6284.1 MiB      0.0 MiB           2       with torch.no_grad():
   141   6284.1 MiB    316.8 MiB           1           outputs = model(input_ids = input_ids, output_hidden_states = True)
   142
   153   6284.1 MiB      0.0 MiB           1       torch.cuda.empty_cache()
   157   6284.1 MiB      0.0 MiB           1       gc.collect()
   163                                             #print("+"*5, psutil.virtual_memory().available/(1024 ** 3))
   165   6284.1 MiB      0.0 MiB           1       return outputs

I have this function in my code, and every time inference runs on the line outputs = model(input_ids=input_ids, output_hidden_states=True), it consumes another 300-600 MB of the available memory. Even though I run garbage collection, the memory still adds up on every call to model_inference, until the program runs out of memory. I am passing one sample at a time.

Can you please suggest what I should do here?
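
One idea I had, though I am not sure it is the right fix, is to detach the hidden states, move them to CPU, and drop the rest of the output before returning, so the caller never holds the full output object. A minimal sketch of that version (assuming a Hugging Face-style model whose output has a hidden_states tuple):

import gc

import torch


def model_inference(
    model,
    device,
    input_ids
):
    input_ids = input_ids.to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, output_hidden_states=True)

    # copy only what is needed to CPU, then drop the original output object
    hidden_states = tuple(h.detach().cpu() for h in outputs.hidden_states)
    del outputs

    gc.collect()
    torch.cuda.empty_cache()

    return hidden_states

Would something along these lines be the right direction, or is the accumulation coming from somewhere else?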