PyTorch memory leak reference cycle in for loop

I am facing a memory leak when iteratively updating tensors on the GPU of my M1 Mac using PyTorch's MPS backend. Here is a minimal reproducible example:

```python
import torch

def leak_example(p1, device):
    # Random perturbation with the same shape as p1
    t1 = torch.rand_like(p1, device=device)
    u1 = p1.detach() + 2 * t1.detach()
    B = torch.rand_like(u1, device=device)
    mask = u1 < B
    # Copy u1, then overwrite the entries where mask is False with fresh random values
    a1 = u1.detach().clone()
    a1[~mask] = torch.rand_like(a1)[~mask]
    return a1

if torch.cuda.is_available():  # CUDA GPUs
    device = torch.device("cuda")
elif torch.backends.mps.is_available():  # Mac GPUs
    device = torch.device("mps")
else:  # fall back to the CPU so that `device` is always defined
    device = torch.device("cpu")

p1 = torch.rand(5, 5, 224, 224, device=device)
for i in range(10000):
    # Each iteration's output becomes the next iteration's input
    p1 = leak_example(p1, device)
```
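To check whether the boolean-mask indexed assignment (`a1[~mask] = ...`) is involved, one experiment is to express the same update with `torch.where`, which builds the result without an in-place indexed write. This is a sketch for testing, not a confirmed fix; the helper names `masked_update_indexing`, `masked_update_where` and `leak_example_where` are hypothetical.

```python
import torch

def masked_update_indexing(u1, mask, r):
    # Original formulation: clone, then boolean-mask indexed assignment
    a1 = u1.clone()
    a1[~mask] = r[~mask]
    return a1

def masked_update_where(u1, mask, r):
    # Same values without indexed assignment:
    # keep u1 where mask is True, take r where mask is False
    return torch.where(mask, u1, r)

def leak_example_where(p1, device):
    # Hypothetical variant of leak_example that uses torch.where instead of a1[~mask] = ...
    t1 = torch.rand_like(p1, device=device)
    u1 = p1.detach() + 2 * t1
    B = torch.rand_like(u1, device=device)
    mask = u1 < B
    return torch.where(mask, u1, torch.rand_like(u1))
```

Swapping `leak_example_where` into the loop and watching whether memory still grows would help narrow down which operation is responsible.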

My Mac’s GPU memory grows steadily while this loop runs. I also ran it on a CUDA GPU in Google Colab, and it seems to behave similarly: the Active memory, Non-releasable memory, and Allocated memory reported for the GPU all increase as the loop progresses.
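A minimal sketch for tracking those counters inside the loop, assuming a recent PyTorch. On CUDA it reads `torch.cuda.memory_stats()`, whose `allocated_bytes`, `active_bytes` and `inactive_split_bytes` entries back the "Allocated memory", "Active memory" and "Non-releasable memory" rows of `torch.cuda.memory_summary()`; on MPS it assumes the `torch.mps.current_allocated_memory()` and `torch.mps.driver_allocated_memory()` queries are available (PyTorch 2.0 or later). The names `memory_snapshot` and `run_and_monitor` are hypothetical.

```python
import torch

MIB = 1024 ** 2

def memory_snapshot(device):
    """Return a few allocator counters, in MiB, for the given device."""
    if device.type == "cuda":
        stats = torch.cuda.memory_stats(device)
        # These keys back the "Allocated memory", "Active memory" and
        # "Non-releasable memory" rows printed by torch.cuda.memory_summary()
        return {
            "allocated": stats.get("allocated_bytes.all.current", 0) / MIB,
            "active": stats.get("active_bytes.all.current", 0) / MIB,
            "non_releasable": stats.get("inactive_split_bytes.all.current", 0) / MIB,
            "reserved": torch.cuda.memory_reserved(device) / MIB,
        }
    if device.type == "mps":
        # Assumes a PyTorch build that provides these torch.mps queries (2.0 or later)
        return {
            "allocated": torch.mps.current_allocated_memory() / MIB,
            "driver": torch.mps.driver_allocated_memory() / MIB,
        }
    return {}  # this sketch reports nothing for CPU tensors

def run_and_monitor(step_fn, p, device, iters=1000, every=100):
    """Call step_fn repeatedly, feeding its output back in, and record memory every `every` steps."""
    history = []
    for i in range(iters):
        p = step_fn(p, device)
        if i % every == 0:
            snap = memory_snapshot(device)
            history.append(snap)
            print(i, {k: round(v, 1) for k, v in snap.items()})
    return p, history
```

With the example above this would be called as `p1, history = run_and_monitor(leak_example, p1, device)`; a steadily rising `allocated` value across the printed rows is the signature of a leak, while a flat `allocated` with a rising `reserved` or `driver` value points at allocator caching instead.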

I have tried detaching and cloning the tensors and using weakrefs, to no avail. Interestingly, if I don’t reassign the output of leak_example to p1, the behavior disappears, so it really seems related to feeding each output back in as the next input (the repeated reassignment of p1). Does anyone have any idea how I could resolve this?
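One way to test the reference-cycle hypothesis on the Python side is to keep a weak reference to each previous `p1` and count how many are still alive after the loop. The sketch below repeats the body of `leak_example` under the hypothetical name `leak_step` so that it runs on its own; `count_surviving_inputs` is also hypothetical. If it returns 0, the earlier tensors are not being kept alive by Python references, which would suggest that any growth happens below Python (for example in the backend's allocator). It does not by itself identify the cause.

```python
import weakref
import torch

def leak_step(p1, device):
    # Same computation as leak_example, repeated so this block is self-contained
    t1 = torch.rand_like(p1, device=device)
    u1 = p1.detach() + 2 * t1.detach()
    mask = u1 < torch.rand_like(u1, device=device)
    a1 = u1.detach().clone()
    a1[~mask] = torch.rand_like(a1)[~mask]
    return a1

def count_surviving_inputs(device, iters=20, shape=(2, 2, 16, 16)):
    """Reassign p1 in a loop and count how many previous p1 tensors are still alive."""
    p1 = torch.rand(*shape, device=device)
    refs = []
    for _ in range(iters):
        refs.append(weakref.ref(p1))  # a weak reference does not keep p1 alive
        p1 = leak_step(p1, device)
    # Only the current p1 should be reachable; earlier ones should already be freed
    return sum(r() is not None for r in refs)
```

Calling `count_surviving_inputs(device)` on the MPS or CUDA device from the example shows whether the old tensors are released there as well.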

I cannot reproduce the issue on a 3090. After adding print(torch.cuda.memory_allocated()/1024**2) to the for loop, I get:


Is this post also related to the one you created yesterday?
If so, could you create a GitHub issue for the potential memory leak on MPS?

Yes, it’s the same issue but this is a minimal reproducible example.
I also get a stable memory_allocated() on Colab, but the output of torch.cuda.memory_summary() seems to indicate that some other background memory allocations are growing. I’m not sure, though; I’m quite inexperienced with CUDA programming.
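To see which of the counters behind `torch.cuda.memory_summary()` actually grow, one option is to take two `torch.cuda.memory_stats()` snapshots and diff them. This is a CUDA-only sketch; `growing_counters` and `find_growing_counters` are hypothetical helpers, and the warm-up length is an arbitrary assumption meant to let the caching allocator settle first.

```python
import torch

def growing_counters(before, after, min_delta=0):
    """Return {key: increase} for numeric counters that grew by more than min_delta."""
    grown = {}
    for key, new in after.items():
        old = before.get(key, 0)
        if isinstance(new, (int, float)) and new - old > min_delta:
            grown[key] = new - old
    return grown

def find_growing_counters(step_fn, p, device, warmup=10, iters=100):
    """Run step_fn, snapshot torch.cuda.memory_stats() before and after, and diff them."""
    for _ in range(warmup):  # let the caching allocator reach a steady state
        p = step_fn(p, device)
    torch.cuda.synchronize(device)
    before = torch.cuda.memory_stats(device)
    for _ in range(iters):
        p = step_fn(p, device)
    torch.cuda.synchronize(device)
    after = torch.cuda.memory_stats(device)
    # Cumulative ".allocated"/".freed" counters always increase, so keep only ".current" ones
    return {k: v for k, v in growing_counters(before, after).items() if k.endswith(".current")}
```

On a CUDA device, `find_growing_counters(leak_example, p1, device)` returning an empty dict would mean the current usage is flat between the two snapshots; any key that appears (for example one starting with `inactive_split_bytes`) names the summary row that is growing.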

I’ll open a GitHub issue. Thank you!

I think I found the cause. Here is the GitHub issue, including my own response: "PyTorch memory leak reference cycle in for loop, Mac M1" (pytorch/pytorch issue #91368).