Out of Memory and Can't Release GPU Memory

I wrap the forward and backward passes in a try/except block, and I also delete all the tensors after every batch.

            try:
                decoder_output, loss = model(batch)
                if Q.qsize() > 0:                     # Q is a queue set elsewhere; stop training when it has items
                    break
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss.append(loss.mean().item())
                del decoder_output, loss              # drop references to the outputs and the graph
            except Exception as e:
                # on OOM (or any other error): clear and drop the gradients, then release cached blocks
                optimizer.zero_grad()
                for p in model.parameters():
                    if p.grad is not None:
                        del p.grad
                torch.cuda.empty_cache()
                oom += 1
            del batch
            # drop the gradients again and release cached memory after every batch
            for p in model.parameters():
                if p.grad is not None:
                    del p.grad
            torch.cuda.empty_cache()
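
For reference, a narrower variant of this pattern catches only CUDA OOM errors instead of every exception, so unrelated bugs still surface. This is a rough sketch rather than my exact loop, and optimizer.zero_grad(set_to_none=True) stands in for the manual deletion of p.grad:

    import torch

    oom = 0
    for i, batch in enumerate(train_loader):
        try:
            decoder_output, loss = model(batch)
            optimizer.zero_grad(set_to_none=True)   # drops the .grad tensors instead of zero-filling them
            loss.backward()
            optimizer.step()
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise                               # let real bugs propagate
            oom += 1
            # drop every reference to this step's tensors before asking the
            # caching allocator to release its unused blocks
            decoder_output = loss = None
            optimizer.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
        finally:
            del batch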

The training process is normal for the first few thousand steps: even when an OOM exception occurs, it is caught and the GPU memory is released.

But after thousands of batches, it suddenly starts hitting OOM on every batch, and the memory never seems to be released anymore.
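
One way to confirm that is to log the CUDA allocator statistics every few hundred steps and watch whether they grow without bound (a minimal sketch; the helper name and the logging interval are mine):

    import torch

    def log_cuda_memory(step, device=None):
        # memory currently held by live tensors vs. memory held by the caching allocator
        allocated = torch.cuda.memory_allocated(device) / 2**20
        reserved = torch.cuda.memory_reserved(device) / 2**20
        peak = torch.cuda.max_memory_allocated(device) / 2**20
        print(f"step {step}: allocated {allocated:.0f} MiB, "
              f"reserved {reserved:.0f} MiB, peak {peak:.0f} MiB")

    # inside the training loop, e.g.:
    # if i % 500 == 0:
    #     log_cuda_memory(i)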

It’s so weird to me. Does anyone have suggestions? (I’m using DistributedDataParallel.)

Here is the code before the try/except, i.e., the code that prepares the input data from train_loader:

        for i, batch in enumerate(train_loader):
            # drop the loader's batch dimension and move every tensor to the GPU
            batch = {k: v.squeeze(0).cuda(non_blocking=True) for (k, v) in batch.items()}
            # split the father indices per tree and convert them (and the tree sizes) to Python lists
            batch['rem_father_index'] = torch.split(batch['rem_father_index'], batch['rem_root_num'].tolist(), dim=0)
            batch['rem_father_index'] = [l.tolist() for l in batch['rem_father_index']]
            batch['tree_sizes'] = batch['tree_sizes'].tolist()
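
As an aside, cuda(non_blocking=True) only overlaps the copy with computation when the source tensors are in pinned memory, so I assume the loader is built roughly like this (dataset, batch size, and worker count are placeholders):

    from torch.utils.data import DataLoader

    # pin_memory=True is what makes .cuda(non_blocking=True) an asynchronous copy
    train_loader = DataLoader(
        train_dataset,      # placeholder dataset
        batch_size=1,       # matches the squeeze(0) above
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )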

I have some updates:

  1. If I create some tensors like torch.zeros(bsz, max_len), can this cause any side effects?
  2. I dumped the objects tracked by the Python garbage collector (roughly with the snippet after this list) and found something weird:

    There is a huge dict containing a lot of tensors. Is this part of the autograd graph or something?
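
This is roughly the snippet I use to dump the tensors the garbage collector tracks (a sketch; it only sees Python-side objects, so storage reachable only from C++ won't show up):

    import gc
    import torch

    def dump_live_tensors():
        # print every tensor (and tensor-carrying object) the Python GC currently tracks
        for obj in gc.get_objects():
            try:
                if torch.is_tensor(obj):
                    print(type(obj).__name__, tuple(obj.size()), obj.device)
                elif hasattr(obj, "data") and torch.is_tensor(obj.data):
                    print(type(obj).__name__, tuple(obj.data.size()), obj.data.device)
            except Exception:
                pass  # some tracked objects raise on attribute access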

I’m running into something similar now… Did you find any solution?

This issue only occurs with DistributedDataParallel.

In the end I went back to DataParallel and it doesn’t have this problem, although it’s slower.

It happens to me without DDP :frowning:
I opened an issue here, possibly related: OOM during backward() leads to memory leaks · Issue #82218 · pytorch/pytorch · GitHub