Memory leak when using forward hook and backward hook simultaneously

When I register a simple forward hook and a backward hook on any CNN model, a memory leak occurs.
Here’s a simple snippet to reproduce it.

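import gc

import torch
from torchvision import models

def bar():
    model = models.resnet101()
    a, b = None, None
    target_layer = model.layer4[-1].conv3

    def forward_hook(module, input, output):
        nonlocal a
        a = output.clone()

    def backward_hook(module, grad_input, grad_output):
        nonlocal b
        b = grad_output[0].clone()

    # Register both hooks on the target layer (options 2 and 3 below refer to these calls).
    target_layer.register_forward_hook(forward_hook)
    target_layer.register_backward_hook(backward_hook)

    # Assumed: one forward and backward pass so both hooks fire; any scalar loss works.
    image = torch.randn(1, 3, 224, 224)
    model(image).sum().backward()

def main():
    for _ in range(10):
        bar()
        # Count every tensor the Python gc still tracks after bar() has returned.
        cnt = 0
        for obj in gc.get_objects():
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                cnt += 1
        print('cnt', cnt)

if __name__ == '__main__':
    main()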

When I run the above code with PyTorch 0.4.0, the final cnt value is nonzero, and it grows every time I call the function bar.

The interesting point is that if I do any one of the following, cnt becomes zero.

  1. del a, b at the end of the function bar.
  2. comment out register_forward_hook (backward hook still working)
  3. comment out register_backward_hook (forward hook still working)
  4. comment out nonlocal a
  5. comment out nonlocal b

Any help would be greatly appreciated. Thanks!

Could you post this to GitHub as an issue? This looks like a bug to me.



This is “expected” behaviour as stated in this comment in the code.

For the record, in your case you trigger this because:

  • The last Function, the one on which you put the backward hook, holds that backward hook.
  • Because it is wrapped in a functools.partial, this backward hook references the parent nn.Module.
  • This nn.Module contains the forward hook you registered.
  • This forward hook contains the forward_hook Python function.
  • This function references its closure.
  • This closure references the nonlocal variable a.
  • This variable a is linked via the backward graph to the output of the last Function.
  • This output has a grad_fn field linking back to the last Function, but that link is not traversed by the Python gc.
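
To make the chain above concrete, here is a minimal pure-Python sketch of the same shape of cycle. Node and Output are hypothetical stand-ins for the autograd Function and the tensor; nothing here is PyTorch code. The important difference is that in pure Python every link is visible to the gc, so the cycle does get collected. In the PyTorch case the output-to-grad_fn link is hidden from the gc, which is why the objects stay alive.

```python
import gc
import weakref


class Node:
    """Hypothetical stand-in for the autograd Function that stores hooks."""

    def __init__(self):
        self.hooks = []


class Output:
    """Hypothetical stand-in for a tensor; grad_fn points back to its Node."""

    def __init__(self, grad_fn):
        self.grad_fn = grad_fn


def run(clear_saved):
    """Build the hook/closure cycle and return a weak reference to the Node."""
    node = Node()
    saved = None

    def hook(out):
        nonlocal saved
        saved = out  # the closure cell now references the output

    node.hooks.append(hook)  # node -> hook -> closure cell
    hook(Output(node))       # closure cell -> output -> node closes the cycle
    if clear_saved:
        saved = None         # drops the reference, like `del a, b` in bar
    return weakref.ref(node)


gc.disable()  # keep the automatic collector from running mid-demo
leaked = run(clear_saved=False)
freed = run(clear_saved=True)
alive_before_gc = leaked() is not None  # cycle keeps refcounts above zero
freed_without_gc = freed() is None      # no cycle, refcounting is enough
gc.collect()
alive_after_gc = leaked() is not None   # gc can traverse every link here
gc.enable()
print(alive_before_gc, freed_without_gc, alive_after_gc)  # True True False
```

Clearing the captured variable breaks the cycle at the closure, which matches the observation that `del a, b` at the end of bar makes the count go back to zero.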

The modified foo:

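import torch
import torch.nn as nn

class MyFn(torch.autograd.Function):
    # forward and backward must be static methods when the Function is used via .apply
    @staticmethod
    def forward(ctx, inp):
        return inp.clone()

    @staticmethod
    def backward(ctx, gout):
        return gout.clone()

class MyMod(nn.Module):
    def forward(self, inp):
        return MyFn.apply(inp)

def foo():
    mod = MyMod()
    inp = torch.rand(2, 3, requires_grad=True)

    # One-element list so the forward hook can store the output without nonlocal.
    l = [None]

    def fw_hook(mod, inp, res):
        l[0] = res

    def bw_hook(mod, g_inp, g_res):
        pass  # placeholder body; registering the hook is what matters here

    # Assumed registration calls, matching the hooks defined above.
    mod.register_forward_hook(fw_hook)
    mod.register_backward_hook(bw_hook)

    out = mod(inp)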

And the corresponding dependence graph:

Note that the missing link between the tensor and the Function prevents the gc from detecting the cycle.


So does the problem originate from the fact that, when there is a tensor, we don’t let the Python gc know that the tensor holds a reference to its grad_fn? And is that because, if we did, the gc would have to traverse back through the deep gradient dependency graph, which would cost a lot of performance?

Yes exactly, you would need to traverse the whole graph every time you delete a python object that references it.
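
As a rough illustration of that cost, here is a toy sketch in plain Python. Fn is a hypothetical stand-in for one node of the backward graph, and gc.get_referents is used to follow the references the collector can see. Starting from the newest node, the walk has to visit every node in the chain, so the work grows with the size of the graph.

```python
import gc


class Fn:
    """Hypothetical stand-in for one node of a backward graph."""

    def __init__(self, prev=None):
        self.prev = prev  # link to the node that produced this node's input


def build_chain(depth):
    """Build a linear graph of `depth` nodes and return the newest one."""
    node = None
    for _ in range(depth):
        node = Fn(node)
    return node


def count_reachable_nodes(root):
    """Follow gc-visible references from `root` and count the Fn nodes seen."""
    seen = set()
    stack = [root]
    count = 0
    while stack:
        obj = stack.pop()
        # Only follow nodes and their attribute dicts; skip classes, None, etc.
        if id(obj) in seen or not isinstance(obj, (Fn, dict)):
            continue
        seen.add(id(obj))
        if isinstance(obj, Fn):
            count += 1
        stack.extend(gc.get_referents(obj))
    return count


visited = count_reachable_nodes(build_chain(1000))
print(visited)  # 1000: every node in the chain is visited
```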
