Memory leak when using forward hook and backward hook simultaneously

This is “expected” behaviour as stated in this comment in the code.

For the record, in your case you trigger this because of the following reference chain (a sketch that reproduces it is shown after the list):

  • The last Function in the graph (the one the backward hook ends up registered on) holds a reference to that backward hook.
  • This backward hook, because of the partial used to register it, references the parent nn.Module.
  • This nn.Module contains the forward hook you registered.
  • This forward hook entry contains the forward_hook Python function.
  • This function references its closure.
  • This closure references the nonlocal variable a.
  • This variable a is linked, via the backward graph, to the output of the last Function.
  • This output has a grad_fn field pointing back to the last Function, and that link is not traversed by the garbage collector.

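For concreteness, here is a minimal sketch of the kind of foo that builds this chain. It is a reconstruction from the bullets above, not the exact code from the original post: the leaky_foo name and the nn.Linear stand-in are made up, and the weakref return is only there so the leak can be demonstrated further down.

import weakref

import torch
from torch import nn


def leaky_foo():
    # nn.Linear is just a stand-in; the MyMod defined below works the same way.
    mod = nn.Linear(3, 3)
    inp = torch.rand(2, 3, requires_grad=True)

    a = None

    def fw_hook(mod, inp, res):
        # The closure's nonlocal a keeps the graph-connected output alive.
        nonlocal a
        a = res

    def bw_hook(mod, g_inp, g_res):
        pass

    mod.register_forward_hook(fw_hook)
    mod.register_backward_hook(bw_hook)

    # Calling the module wraps bw_hook in a partial that references mod
    # and registers it on out.grad_fn (the last Function), as described above.
    out = mod(inp)

    # Cycle: out -> grad_fn -> partial(bw_hook, mod) -> mod -> fw_hook
    #        -> closure -> a -> out
    return weakref.ref(out)
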
The modified foo:

import torch
from torch import nn


class MyFn(torch.autograd.Function):
    # Identity-like custom Function, just so the output gets a grad_fn.
    @staticmethod
    def forward(ctx, inp):
        return inp.clone()

    @staticmethod
    def backward(ctx, gout):
        return gout.clone()

class MyMod(nn.Module):
    def forward(self, inp):
        return MyFn.apply(inp)

def foo():
    mod = MyMod()
    inp = torch.rand(2, 3, requires_grad=True)

    # The list plays the role of the nonlocal variable above: the forward
    # hook stores the graph-connected output in it.
    l = [None]

    def fw_hook(mod, inp, res):
        l[0] = res

    def bw_hook(mod, g_inp, g_res):
        pass

    mod.register_forward_hook(fw_hook)
    mod.register_backward_hook(bw_hook)

    out = mod(inp)
    # Inspect the backward graph hanging off the output.
    print(out.grad_fn.next_functions)

And the corresponding dependency graph:


Note that the missing link between the tensor and the Function prevents the garbage collector from detecting the cycle.
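
One way to observe the consequence of that missing link (a minimal sketch, assuming the leaky_foo reconstruction above and a PyTorch version where register_backward_hook behaves as described in this thread):

import gc

ref = leaky_foo()  # every strong reference to the output now lives inside the cycle
gc.collect()       # the collector cannot follow the tensor -> grad_fn link
print(ref() is not None)  # True on affected versions: the output and its graph were never freed

Breaking any link of the chain avoids the leak, e.g. not storing the graph-connected output in the forward hook's closure (store res.detach() instead), or removing the hooks via the handles returned by register_forward_hook / register_backward_hook once they are no longer needed.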
