GPU memory not cleared after training when saving stats for plotting

When I print the GPU memory usage during the training and validation iterations, I notice that something allocated in the training iteration is not released during the validation iteration. Output of the print statements:

...
training: 
3940352
validation: 
131654144
training: 
3940352
validation: 
131654144
...
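
For reference, these numbers come from the torch.cuda.memory_allocated() call inside iteration() (shown below). If a fuller breakdown would help, I could also add something like this next to that print (not in my current code, just what I have in mind):

    # possible extra debug output, not part of my current code:
    print(torch.cuda.memory_summary(device=device, abbreviated=True))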

For reference, my code looks something like this:

def iteration(step, batch, model, device, stats_dict, optim_schedule=None, criterion=None, train=True):
    if train:
        model.train()
    else:
        model.eval()
    
    batch = {key: value.to(device) for key, value in batch.items()}
    output = model(batch['input_ids'])   # call the module directly instead of model.forward(...)
    loss = criterion(output, batch['label']) 
    if train:
        optim_schedule.zero_grad()
        loss.backward()
        optim_schedule.step_and_update_lr()
    
    # for debugging GPU memory issues:
    if train: 
        print('training: ') 
    else: 
        print('validation: ')
    print(torch.cuda.memory_allocated())

    #################################
    # gather loss and correct prediction statistics for this data chunk
    #################################
    predicted = output.argmax(dim=-1).cpu()
    num_correct = (predicted == batch['label'].cpu()).sum().item()   # attempt at solving the bug by moving to CPU

    stats_dict['loss'] += loss.item()
    stats_dict['num_correct'] += num_correct
    stats_dict['nb_tr_examples'] += batch['input_ids'].size(0)
    stats_dict['nb_tr_steps'] += 1
    return stats_dict   

And in the main function:

def main():
    # ....

    # Accumulate loss for plotting training history
    train_hist = {'loss': [],
                  'acc': []}
    
    val_hist = {'loss': [],
                'acc': []}

    train_stats = {'loss': 0,
                   'num_correct': 0,
                   'nb_tr_steps': 0,
                   'nb_tr_examples': 0,
                   }

    val_stats = {'loss': 0,
                 'num_correct': 0,
                 'nb_tr_steps': 0,
                 'nb_tr_examples': 0,
                 }
    
    #################################
    # Start looping through the epochs
    #################################

    for epoch in tqdm.trange(int(args.num_train_epochs), desc="Epoch"):
        for step, batch in enumerate(tqdm.tqdm(train_dataloader, desc="Iteration")):
            if args.train_dataset:
                train_stats = iteration(step, batch, model, device, train_stats, optim_schedule, criterion, train=True)
                
                if step % args.time_steps_per_plot_point == 0:
                    # let's add a new point to the array that will be plotted:
                    train_hist['loss'].append(train_stats['loss'] / train_stats['nb_tr_steps'])
                    train_hist['acc'].append(train_stats['num_correct'] / train_stats['nb_tr_examples'])

            if args.val_dataset:
                val_stats = iteration(step, batch, model, device, val_stats, optim_schedule, criterion, train=False)
                
                if step % args.time_steps_per_plot_point == 0:
                    val_hist['loss'].append(val_stats['loss'] / val_stats['nb_tr_steps'])
                    val_hist['acc'].append(val_stats['num_correct'] / val_stats['nb_tr_examples'])

I’m relatively new to PyTorch, and I’d appreciate any tips on how to debug this or on writing memory-efficient PyTorch code. I’ve come across "GPU memory not fully released after training loop", but if the stats_dict is what is causing the problem, what is the best way to still track performance for plotting?
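
Would something like the following be the right direction? This is just an untested sketch of how I imagine changing the forward/loss part of iteration(), wrapping the validation pass in torch.no_grad():

    # untested sketch -- would this keep the validation pass from
    # building/retaining the autograd graph?
    if train:
        output = model(batch['input_ids'])
        loss = criterion(output, batch['label'])
        optim_schedule.zero_grad()
        loss.backward()
        optim_schedule.step_and_update_lr()
    else:
        with torch.no_grad():   # no graph is recorded inside this block
            output = model(batch['input_ids'])
            loss = criterion(output, batch['label'])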