When I use a simple forward hook and a backward hook on any CNN model, a memory leak occurs. Here's a simple snippet to reproduce the phenomenon:
import gc

import torch
from torchvision import models


def bar():
    model = models.resnet101()
    model.eval()
    a, b = None, None
    target_layer = model.layer4[-1].conv3

    def forward_hook(module, input, output):
        # Keep a clone of the layer's activation.
        nonlocal a
        a = output.clone()

    def backward_hook(module, grad_input, grad_output):
        # Keep a clone of the gradient w.r.t. the layer's output.
        nonlocal b
        b = grad_output[0].clone()

    target_layer.register_forward_hook(forward_hook)
    target_layer.register_backward_hook(backward_hook)

    image = torch.randn(1, 3, 224, 224)
    model(image)


def main():
    for _ in range(10):
        bar()
        gc.collect()
    # Count every tensor (or parameter wrapping a tensor) still tracked by the GC.
    cnt = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                cnt += 1
        except Exception:
            pass
    print('cnt', cnt)


if __name__ == '__main__':
    main()
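(Side note: as written, bar only runs the forward pass, so b always stays None, yet the count still grows. If you want backward_hook to actually fire, the model(image) line inside bar can be replaced with a backward pass, something like:)

    out = model(image)
    # Any scalar works; calling backward() is what triggers backward_hook.
    out.mean().backward()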
When I run the above code with PyTorch 0.4.0, the final cnt value is nonzero, and it grows every time bar is called.
The interesting point is that if I do any one of the following, cnt drops to zero (the first option is sketched in code after this list):
- del a, b at the end of the function bar.
- Comment out register_forward_hook (the backward hook still registered).
- Comment out register_backward_hook (the forward hook still registered).
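For concreteness, this is roughly what the first option looks like; bar_fixed is just my name for the variant, and I also keep the handles returned by the register_* calls so the hooks can be detached from the module before returning (an untested sketch):

def bar_fixed():
    model = models.resnet101()
    model.eval()
    a, b = None, None
    target_layer = model.layer4[-1].conv3

    def forward_hook(module, input, output):
        nonlocal a
        a = output.clone()

    def backward_hook(module, grad_input, grad_output):
        nonlocal b
        b = grad_output[0].clone()

    # Keep the handles so the hooks can be removed again.
    fwd_handle = target_layer.register_forward_hook(forward_hook)
    bwd_handle = target_layer.register_backward_hook(backward_hook)

    model(torch.randn(1, 3, 224, 224))

    # Option 1: drop the local references to the cloned tensors.
    del a, b
    # Removing the hooks is good hygiene either way.
    fwd_handle.remove()
    bwd_handle.remove()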
So does the problem originate from the fact that the Python GC is never told a tensor holds a reference to its grad_fn? I can see why that might be deliberate: if the reference were exposed, the GC would have to traverse the whole, potentially deep gradient dependency graph, which would cost significant efficiency.
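If so, that would also explain why del a, b helps: a = output.clone() carries the whole autograd graph behind it through grad_fn, and that chain is exactly what the GC cannot see. A tiny illustration of the difference (storing a detached copy severs the chain, which might be an alternative workaround when you don't need gradients of the saved activation):

conv = torch.nn.Conv2d(3, 8, 3)
out = conv(torch.randn(1, 3, 224, 224))

kept = out.clone()
print(kept.grad_fn)                 # CloneBackward: still attached to the graph
print(kept.grad_fn.next_functions)  # walks back toward conv's parameters

freed = out.detach().clone()
print(freed.grad_fn)                # None: no reference back into the graph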