MWE for a memory leak?

As an exercise for my students, I am creating various small PyTorch examples that each exhibit a different issue, which the students then have to debug and fix. I would like to include an example with a memory leak on the GPU.
However, I am having a hard time actually constructing a small example that leaks memory. This was my best attempt:

import torch
import torch.nn as nn

device = torch.device('cuda')

input_size = 500
hidden_size = 700

Xset = torch.utils.data.TensorDataset(torch.rand(1000000, input_size), torch.rand(1000000, 1))

lossF = nn.functional.huber_loss

Bob_net = nn.Sequential(nn.Linear(input_size, hidden_size),
                        nn.ReLU(),
                        nn.Linear(hidden_size, hidden_size),
                        nn.ReLU(),
                        nn.Linear(hidden_size, hidden_size),
                        nn.ReLU(),
                        nn.Linear(hidden_size, hidden_size),
                        nn.ReLU(),
                        nn.Linear(hidden_size, hidden_size),
                        nn.ReLU(),
                        nn.Linear(hidden_size, 1))

trainLoader = torch.utils.data.DataLoader(Xset, batch_size=64)

Bob_net.to(device)
optimizer = torch.optim.Adam(Bob_net.parameters())
losses = []

for iEpoch in range(30):
  print(f"Allocated memory: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
  for xbatch, ybatch in trainLoader:
    xbatch = xbatch.to(device)
    ybatch = ybatch.to(device)

    pred = Bob_net(xbatch)
    loss = lossF(ybatch, pred)

    Bob_net.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss)  # the loss tensor stays on the GPU and still references its graph

Here, the allocated memory does grow, but apparently only because each loss tensor stays on the GPU, so the rising memory usage just reflects an ever-longer list of small tensors. My intention was that, since the loss is never detached, every batch's graph would be retained, leading to a much larger increase in memory usage. However, the growth in memory allocation is the same as if I just do:

L=[]
for i in range(468750):
  L.append(torch.rand(1).to(device))

(where 468750 happens to be the length of the ‘losses’ list from the first example).
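
For reference, the usual fix is to store a detached value instead; with that change the allocated memory should stay essentially flat, which again suggests the growth above comes only from the stored scalars and not from retained graphs. A minimal sketch of the changed last line of the loop (everything else unchanged):

losses.append(loss.item())          # copies the value to the CPU as a plain Python float
# or, if the tensor itself is needed:
losses.append(loss.detach().cpu())  # detached from the graph and moved off the GPU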

So, does anyone have an idea for how to make a small example of what can go wrong if you keep references to the graph? Or have I misunderstood the point of ‘detach()’?

I’m not sure detach() is the best way to show memory leakage on the GPU. detach() is similar to setting requires_grad=False, but for an entire sub-graph.
Imagine a simple matrix multiplication where one input doesn’t require gradients, and we just want to let the gradient flow from the output through the other input:

x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=False)
z = torch.matmul(x, y)
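
A quick check of that behaviour (calling backward on a scalar reduction of z, just for illustration):

z.sum().backward()
print(x.grad is not None)  # True: the gradient flowed through x
print(y.grad)              # None: y is treated as a constant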

Now imagine y is the result of some complicated calculation with large inputs:

y0 = torch.randn(1024 * 32, 1024, requires_grad=True)
y1 = torch.randn(1024, 1024 * 32, requires_grad=True)
y2 = torch.matmul(y0, y1)
y3 = torch.matmul(y2, torch.randn(1024 * 32, 4))
y4 = torch.matmul(y2, torch.randn(1024 * 32, 4))
y = torch.matmul(y3.t(), y4)
y = y.detach()
# rest of the model before x/y
x = torch.randn(4, 4, requires_grad=True)
z = torch.matmul(x, y)
# rest of the model after z

In this case, it makes a significant difference whether y.detach() is called before z = torch.matmul(x, y): if you don’t call detach(), the large matrices y0 / y1 need to be saved for backward.
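
A rough way to see this on the GPU (a sketch, using smaller shapes than above so it fits comfortably on most cards, and assuming a CUDA device is available; the point is that once the local names go out of scope, only the graph hanging off z can keep the saved tensors alive):

import torch

device = "cuda"

def build(detach_y):
    y0 = torch.randn(4096, 1024, device=device, requires_grad=True)
    y1 = torch.randn(1024, 4096, device=device, requires_grad=True)
    y2 = torch.matmul(y0, y1)                                  # large intermediate
    y3 = torch.matmul(y2, torch.randn(4096, 4, device=device))
    y4 = torch.matmul(y2, torch.randn(4096, 4, device=device))
    y = torch.matmul(y3.t(), y4)
    if detach_y:
        y = y.detach()                                         # cut the expensive sub-graph here
    x = torch.randn(4, 4, device=device, requires_grad=True)
    return torch.matmul(x, y)

for detach_y in (True, False):
    z = build(detach_y)  # build()'s locals are gone; only z and its graph remain
    print(detach_y, torch.cuda.memory_allocated() / 1024**2, "MB")
    del z

Without detach, the graph of z keeps y0 / y1 alive as saved tensors for backward (tens of MB here); with detach, they can be freed as soon as build() returns.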

Something like this happens in physics-simulation models where you might have multiple paths through the model, corresponding to different types of simulation, each with its own parameters (y0 / y1 here, plus imagine lots more for the other paths). You only want gradients to flow through one path at a time and update that path's parameters, depending on the type of loss (which might be tuned for a given type of simulation); see the sketch below.
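
A minimal, self-contained sketch of that gating pattern (the small Linear layers stand in for much bigger simulation paths; all names here are made up for illustration):

import torch
import torch.nn as nn

# Two stand-ins for the "simulation paths"; in a real model these would be large sub-networks.
path_a = nn.Linear(8, 4)
path_b = nn.Linear(8, 4)
head = nn.Linear(4, 1)

x = torch.randn(16, 8)
target = torch.randn(16, 1)

active = "a"  # which path the current loss is tuned for

out_a = path_a(x)
out_b = path_b(x)

# Detach the inactive path so no gradients flow into its parameters for this loss.
if active == "a":
    out_b = out_b.detach()
else:
    out_a = out_a.detach()

loss = nn.functional.mse_loss(head(out_a + out_b), target)
loss.backward()

print(path_a.weight.grad is None)  # False: the active path receives gradients
print(path_b.weight.grad is None)  # True: the detached path does not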

Autograd might be smarter than that at this point, pruning the graph buffers after they have been used in backward. Try modifying your code to store the predictions instead, like this:

all_preds = []    # Instead of losses
all_targets = []  # Accumulate the targets as well, so the final loss can be computed
for xbatch, ybatch in trainLoader:
    xbatch = xbatch.to(device)
    ybatch = ybatch.to(device)
    pred = Bob_net(xbatch)  # Full graph here
    all_preds.append(pred)  # Retains the graph (and activations) for every batch!
    all_targets.append(ybatch)
    # Skip backward/step to avoid pruning

# Now compute the loss on the full cat(all_preds) -- the graph explodes!
final_pred = torch.cat(all_preds)
ybatch_all = torch.cat(all_targets)
loss = lossF(ybatch_all, final_pred)

Or you could update loss.backward() in the original script to loss.backward(retain_graph=True), as sketched below.
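
A self-contained sketch of why that combination grows memory (a deliberately tiny model and a CUDA device are assumed; retain_graph=True keeps each iteration's saved activations alive, and the stored loss tensors keep referencing them):

import torch
import torch.nn as nn

device = "cuda"
model = nn.Linear(512, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for i in range(50):
    x = torch.randn(4096, 512, device=device)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward(retain_graph=True)  # the saved tensors (including x) are NOT freed
    optimizer.step()
    losses.append(loss)               # each stored loss pins its retained graph
    if i % 10 == 0:
        print(f"{torch.cuda.memory_allocated() / 1024**2:.1f} MB")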


Thanks. Though I don’t think I can come up with a plausible reason why anyone would do it like that in the first place, which means I still don’t have a good example that the students can be asked to debug…
I guess at this point the takeaway is that GPU memory leakage is largely a myth?

I’m not sure about that. What I can suggest is searching the forum for “out of memory error” or “memory leak”; that might turn up more practical examples.

If I recall correctly, it tends to still be an issue (a solvable one) with RNNs, or in certain RL setups; see the sketch below.
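
For instance, a common RNN pattern (a minimal sketch, assuming a CUDA device): the hidden state is carried across iterations without being detached and the loss is accumulated as a tensor, so every step's graph stays alive and allocated memory keeps growing until backward is finally called or the GPU runs out of memory:

import torch
import torch.nn as nn

device = "cuda"
rnn = nn.RNN(input_size=32, hidden_size=256, batch_first=True).to(device)

h = torch.zeros(1, 64, 256, device=device)
total_loss = 0.0
for step in range(100):
    x = torch.randn(64, 50, 32, device=device)
    out, h = rnn(x, h)                           # h still references the graph of every previous step
    total_loss = total_loss + out.pow(2).mean()  # and so does the accumulated loss tensor
    if step % 25 == 0:
        print(f"{torch.cuda.memory_allocated() / 1024**2:.1f} MB")

# The usual fixes: h = h.detach() after each step (truncated BPTT), and accumulating
# a plain Python number (e.g. via .item()) if the value is only needed for logging.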

Why is forgetting to detach a large graph something that no one would do in the first place?
In complicated networks, it can easily happen.
If the argument is that “no one would do it this way”, why do we have memory leaks from a C++ function returning a heap-allocated pointer? If you follow standard programming practices, you wouldn’t do it that way either, right?

I see the expected memory increase using:

import torch
import torch.nn as nn 

device = "cuda"

batch_size = 1024 * 1024
hidden_dim = 1

model = nn.Sequential(
    nn.Linear(hidden_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim)
).to(device)

print("{:.3f}MB allocated".format(torch.cuda.memory_allocated() / 1024**2))

l = []
for _ in range(10):
    x = torch.randn(batch_size, hidden_dim, device=device)
    out = model(x)
    l.append(out)
    print("{:.3f}MB allocated".format(torch.cuda.memory_allocated() / 1024**2))    

# 0.008MB allocated
# 44.133MB allocated
# 80.133MB allocated
# 116.133MB allocated
# 152.133MB allocated
# 188.133MB allocated
# 224.133MB allocated
# 260.133MB allocated
# 296.133MB allocated
# 332.133MB allocated
# 368.133MB allocated

which shows the memory increase of ~36MB in each iteration, corresponding to the 8 intermediate activations and 1 output:

batch_size * hidden_dim * 4 / 1024**2 * (8+1) = 36

The additional 8MB shown in the first iteration is used as a workspace in cuBLAS.

Note, however, that this memory increase is not strictly speaking a memory leak, as the user is explicitly appending the outputs and can thus still call backward on them afterwards. A leak would mean memory that is lost and can no longer be freed or reused; here, dropping the references releases it again.
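
Continuing the snippet above (a small check, not a definitive test):

# Drop everything that still references the outputs and their graphs.
del l, out, x
print("{:.3f}MB allocated".format(torch.cuda.memory_allocated() / 1024**2))
# This should fall back close to the initial value (plus e.g. the cuBLAS workspace),
# since nothing references the saved intermediates anymore.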