Global pruning uses increasing amounts of memory

I’m pruning a simple model but running into GPU out-of-memory errors. Each pruning iteration allocates more memory instead of holding steady (see output below). Pruning with the “l1_unstructured” function works fine and doesn’t use increasing memory, while the “global_unstructured” function has this issue.

Am I doing something wrong here, or is this a bug?

Here’s simple code to reproduce the error:

import torch
import torch.nn.utils.prune as prune

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(2000, 2000)
    
model = Net()
model.to(device="cuda:0")

for i in range(1, 1001):

    # Doesn't have any memory issue:
    # prune.l1_unstructured(model.fc, name="weight", amount=0.01)

    # Rapidly uses memory:
    prune.global_unstructured(
        [(model.fc, "weight")],
        pruning_method=prune.L1Unstructured,
        amount=0.01,  # fails with 0, 0.01, or 2
    )

    if i % 100 == 0:
        print(f"\nIter {i}: allocated mem (gb) = {torch.cuda.memory_allocated(0) / 1e9}")

And the output, showing multiple GB of GPU memory allocated after only a couple hundred iterations:

Iter 100: allocated mem (gb) = 1.6775424

Iter 200: allocated mem (gb) = 3.303190528

Iter 300: allocated mem (gb) = 4.929615872

Iter 400: allocated mem (gb) = 6.555264

This issue occurs for the following “amount” values: 0, 0.01, 2. All of these values are fine when using the “l1_unstructured” function.

I’m using PyTorch 2.1.2+cu121.

After some investigation I found that global pruning uses a torch.nn.utils.prune.PruningContainer, which stores the full history of pruning masks during iterative pruning. global_unstructured applies its computed mask through CustomFromMask, and each of those keeps a full-size mask tensor in the container: for a 2000x2000 float32 weight that’s ~16 MB per iteration, which matches the ~1.6 GB of growth per 100 iterations shown above. (l1_unstructured also accumulates entries in the container, but each L1Unstructured entry only stores the pruning amount, so memory stays flat.) The container probably wasn’t designed for iterative pruning over hundreds to thousands of steps, which is what I’m trying to do.
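You can watch the container grow with something like the following sketch (it reads the module’s private _forward_pre_hooks dict, so details may differ across PyTorch versions):

# After at least two global_unstructured calls on model.fc, the pruning
# hook is a PruningContainer holding one CustomFromMask (with its own
# full-size mask tensor) per call.
hook = next(iter(model.fc._forward_pre_hooks.values()))
print(type(hook).__name__)      # PruningContainer
print(len(hook))                # grows by one entry per pruning call
print(type(hook[-1]).__name__)  # CustomFromMask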

Ideally there’d be an option to not store the entire history. For now, calling prune.remove() every ~10-100 iterations is a workaround: it makes the pruning permanent and frees the accumulated masks. Calling it after every iteration adds noticeable overhead, so I only do it periodically, as sketched below.
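Here’s a minimal sketch of that workaround (the interval of 100 is an arbitrary choice):

REMOVE_EVERY = 100  # arbitrary; trades memory growth against remove() overhead

for i in range(1, 1001):
    prune.global_unstructured(
        [(model.fc, "weight")],
        pruning_method=prune.L1Unstructured,
        amount=0.01,
    )
    if i % REMOVE_EVERY == 0:
        # Make pruning permanent: fold weight_orig * weight_mask into
        # .weight and drop the PruningContainer and its stored masks.
        prune.remove(model.fc, "weight")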