Help with finding a memory leak

Hi,
I tried something a bit unusual. It is working, but sadly there is a memory leak. Maybe someone can help me find it. Basically, I tried to create a gradient-based integration function:
Fit a model with the gradient values of a function. Here I am fitting both the y values and the gradients in parallel.

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch import FloatTensor as FT


def parabel(a, b, c, x):
    return a*x**2 + b*x + c


class ParabelModel(nn.Module):
    def __init__(self, start_params=[1, 1, 0]):
        super().__init__()
        self.params = nn.Parameter(FT(start_params))
        
    def forward(self, x):
        a, b, c = self.params
        return parabel(a, b, c, x)


def print_usage(device):
    if device.type == 'cuda':
        actual_mem = torch.cuda.memory_allocated() 
        max_mem = torch.cuda.max_memory_allocated()
        print(f"{actual_mem/max_mem :0.2f} actual: {actual_mem:.2e} max: {max_mem:.2e}")
    return


a, b, c = -2, 3, 2
N = 100

device = torch.device('cuda')
# create data
xs_train = torch.linspace(-1, 1, N, requires_grad=True, device=device)
ys_train = parabel(a, b, c, xs_train)
# compute the target gradients dy/dx at the training points
ys_train.backward(torch.ones_like(ys_train))
dx_train = xs_train.grad
xs_train = xs_train.detach()
ys_train = ys_train.detach()

l1 = 1
l2 = 0.00
lr = 0.5

model = ParabelModel([2, 1, 0]).to(device)
optimizer = optim.AdamW(model.parameters(), lr=lr)


for epoch in range(400):
    xs = xs_train.clone().detach().requires_grad_(True)
    ys_pred = model(xs)

    # calculate gradients of the model wrt the inputs
    ys_pred.backward(torch.ones_like(ys_pred), create_graph=True)
    dx_pred = xs.grad

    loss_grad = l1 * F.mse_loss(dx_pred, dx_train)

    model.zero_grad()

    loss_grad.backward()

    loss = l2 * F.mse_loss(ys_pred, ys_train)
    loss.backward()
    
    with torch.no_grad():
        optimizer.step()
        optimizer.zero_grad()
        
    
    if not epoch % 50:
        print_usage(device)
        print(model.params.data)

This gives:
0.63 actual: 6.14e+03 max: 9.73e+03
tensor([1.4900, 1.4950, 0.0000], device='cuda:0')
0.41 actual: 5.73e+04 max: 1.39e+05
tensor([-2.1686, 3.0714, 0.0000], device='cuda:0')
0.57 actual: 1.09e+05 max: 1.90e+05
tensor([-1.9972, 2.9907, 0.0000], device='cuda:0')
0.66 actual: 1.60e+05 max: 2.41e+05
tensor([-1.9885, 2.9921, 0.0000], device='cuda:0')
0.72 actual: 2.11e+05 max: 2.92e+05
tensor([-1.9894, 2.9928, 0.0000], device='cuda:0')
0.76 actual: 2.62e+05 max: 3.44e+05
tensor([-1.9906, 2.9937, 0.0000], device='cuda:0')
0.79 actual: 3.13e+05 max: 3.95e+05
tensor([-1.9915, 2.9943, 0.0000], device='cuda:0')
0.82 actual: 3.65e+05 max: 4.46e+05
tensor([-1.9922, 2.9947, 0.0000], device='cuda:0')

It’s converging to the expected values a, b and c, but the memory grows.
Thanks

Using torch.autograd.grad (which is a better fit conceptually)

    dx_pred, = torch.autograd.grad(ys_pred, xs, torch.ones_like(ys_pred), create_graph=True)

removes the memory leak.
Note that torch.tensor is preferred over constructors like FloatTensor, and we generally avoid .data. You also don’t need the no_grad block for calling optimizer.step().
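
Putting these together, a minimal sketch of how the loop could look (assuming the model, data, and hyperparameters from your post; I folded the two losses into a single backward, which gives the same parameter gradients because backward just accumulates):

for epoch in range(400):
    xs = xs_train.clone().detach().requires_grad_(True)
    ys_pred = model(xs)

    # gradient of the model output w.r.t. the inputs, kept differentiable
    dx_pred, = torch.autograd.grad(
        ys_pred, xs, torch.ones_like(ys_pred), create_graph=True
    )

    # fit gradients and values together
    loss = l1 * F.mse_loss(dx_pred, dx_train) + l2 * F.mse_loss(ys_pred, ys_train)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if not epoch % 50:
        print_usage(device)
        print(model.params.detach())

In the same spirit, ParabelModel could create its parameter with nn.Parameter(torch.tensor(start_params, dtype=torch.float32)) instead of the FloatTensor constructor.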

That said, using backward that way probably shouldn’t leak memory. @albanD might know if there is a story to it.

Best regards

Thomas

Hi,
I really thank you a lot for your answer. It fixes my problem :grinning: But I have to say that manipulating the dynamic graph is still complicated for me. Do you know of a good (maybe visual) explanation of that topic? Something like a correspondence between the PyTorch commands and what is happening to the graph.

The most detailed account is in @ptrblck’s and my imaginary book, which focuses on PyTorch details like that (and tries to avoid explaining deep learning).
The other day I wrote an extension that shows the autograd graph in Jupyter at the end of each cell (or whenever you want), to create the visuals and facilitate interactive exploration. But I didn’t think to include .grad, and I’ll admit that I don’t think it’ll say “memory leak”, so it would be an imperfect solution for the problem here.
If the present book is any indication, it’ll be there in a matter of years. Maybe we should have an early access component. :stuck_out_tongue_winking_eye:
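
In the meantime, one low-tech way to see the correspondence is to walk the graph by hand via grad_fn and next_functions; a minimal standalone sketch (not specific to the model in this thread):

import torch

x = torch.linspace(-1, 1, 5, requires_grad=True)
y = (x ** 2).sum()

# each non-leaf tensor records the backward node that produced it ...
print(y.grad_fn)                 # e.g. <SumBackward0 object at ...>
# ... which in turn links to the nodes for its inputs
print(y.grad_fn.next_functions)  # e.g. ((<PowBackward0 object at ...>, 0),)

def walk(fn, depth=0):
    # recursively print the backward graph reachable from fn
    if fn is None:
        return
    print("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1)

walk(y.grad_fn)   # prints roughly: SumBackward0 / PowBackward0 / AccumulateGrad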

Hey,

Yes, we added a note to the docs about this: https://pytorch.org/docs/master/autograd.html#torch.autograd.backward

The short answer is:

  • If possible, use autograd.grad. It avoids this issue, avoids accumulating unrelated gradients in the .grad field if you forget to .zero_grad(), and avoids computing unneeded gradients if you only want them for the input.
  • If you cannot, replace the model.zero_grad() after the backward with create_graph=True by
for p in model.parameters():
    p.grad = None

This is just as fast but properly removes the leak :slight_smile:
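
For reference, a rough sketch of where that would go in the loop from the first post (everything else unchanged):

for epoch in range(400):
    xs = xs_train.clone().detach().requires_grad_(True)
    ys_pred = model(xs)

    ys_pred.backward(torch.ones_like(ys_pred), create_graph=True)
    dx_pred = xs.grad

    loss_grad = l1 * F.mse_loss(dx_pred, dx_train)

    # instead of model.zero_grad(): drop the graph-carrying .grad tensors
    for p in model.parameters():
        p.grad = None

    loss_grad.backward()

    loss = l2 * F.mse_loss(ys_pred, ys_train)
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()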

Thank you as well. I learned a lot today.

Thank you Alban!

It’s not the parameters here, btw. Here is what happens as far as I can tell:

  • Obviously, xs holds a reference to xs.grad if that is populated, as it is a member.
  • But via the graph built with backward(..., create_graph=True), xs.grad has, through its grad_fn and some next_functions, an indirect reference back to xs.

This reference cycle isn’t just in Python but also in C++, so while the Python bits will go out of scope and can eventually be garbage collected, C++ doesn’t have garbage collection, and so the C++ side of the tensors will remain alive.
If you set xs.grad to None, that removes the C++ link, and the C++ bits will be deallocated as well.
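
So if you want to keep the backward-based version, a minimal sketch of the corresponding workaround (assuming the loop from your first post) is to clear xs.grad once you are done with it:

# inside the training loop from the first post:
ys_pred.backward(torch.ones_like(ys_pred), create_graph=True)
dx_pred = xs.grad

loss_grad = l1 * F.mse_loss(dx_pred, dx_train)
loss_grad.backward()

# break the C++-side cycle xs <-> xs.grad <-> graph so everything
# can be freed when xs and dx_pred are rebound on the next iteration
xs.grad = None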

Best regards

Thomas
