This is “expected” behaviour as stated in this comment in the code.
For the record,
In your case you trigger this because:
- The last Function on which you put the backward hook contains the backward hook.
- This backward hook, because of the `partial` used to bind it to the module (see the sketch after this list), references the parent nn.Module.
- This nn.Module contains the forward hook you registered.
- This forward hook is the `forward_hook` python function.
- This function references its closure.
- This closure references the nonlocal variable `a`.
- This variable `a` is linked via the backward graph to the output of the last Function.
- This output has a `grad_fn` field linking back to the last Function, and that link is not traversed by the garbage collector.
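For context on the `partial` step, here is a minimal runnable sketch of the wiring. `TinyMod` and `bw_hook` are made-up names for illustration, and the binding shown is a simplified version of what `Module.__call__` does for module-level backward hooks, not the verbatim PyTorch code:

```python
import functools

import torch
from torch import nn


class TinyMod(nn.Module):  # hypothetical module, just for the demo
    def forward(self, inp):
        return inp * 2


def bw_hook(mod, g_inp, g_res):
    pass


mod = TinyMod()
out = mod(torch.rand(2, requires_grad=True))

# Simplified version of what Module.__call__ does: bind the module as the
# hook's first argument and register the result on the output's grad_fn.
wrapper = functools.partial(bw_hook, mod)
out.grad_fn.register_hook(wrapper)

# The Function now references the wrapper, and the wrapper references the
# module through its bound first argument:
print(wrapper.args[0] is mod)  # True
```

So the hook stored on the Function keeps the whole nn.Module alive, which is the first edge of the cycle described above.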
The modified foo:
```python
import torch
from torch import nn


class MyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        return inp.clone()

    @staticmethod
    def backward(ctx, gout):
        return gout.clone()


class MyMod(nn.Module):
    def forward(self, inp):
        return MyFn.apply(inp)


def foo():
    mod = MyMod()
    inp = torch.rand(2, 3, requires_grad=True)
    l = [None]

    def fw_hook(mod, inp, res):
        # The closure over `l` keeps the module's output alive.
        l[0] = res

    def bw_hook(mod, g_inp, g_res):
        pass

    mod.register_forward_hook(fw_hook)
    mod.register_backward_hook(bw_hook)
    out = mod(inp)
    print(out.grad_fn.next_functions)
```
And the corresponding dependency graph:

Note that the missing link between the tensor and the Function prevents the gc from detecting the cycle: the tensor's `grad_fn` edge lives on the C++ side of autograd, so the Python garbage collector cannot traverse it.
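To make the leak observable, here is a minimal sketch reusing `MyMod` and the hooks from the snippet above (the helper name `leak_demo` is made up): we keep only a weak reference to the output and show that it survives a full collection, because the cycle runs through `grad_fn` where the gc cannot follow it.

```python
import gc
import weakref


def leak_demo():
    mod = MyMod()
    inp = torch.rand(2, 3, requires_grad=True)
    l = [None]

    def fw_hook(mod, inp, res):
        l[0] = res

    def bw_hook(mod, g_inp, g_res):
        pass

    mod.register_forward_hook(fw_hook)
    mod.register_backward_hook(bw_hook)
    out = mod(inp)
    # Only a weak reference escapes; everything else should be garbage.
    return weakref.ref(out)


ref = leak_demo()
gc.collect()
# The cycle crosses the tensor -> grad_fn boundary, which the Python gc
# does not traverse, so nothing in the cycle is collected.
print(ref() is not None)  # True on the affected versions: the output leaked
```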