register_full_backward_hook causes a memory leak

Hi,

It seems like there is a bug in the new register_full_backward_hook method that causes a memory leak, while the old register_backward_hook does not have this problem.

The following code snippet uses register_full_backward_hook.

import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_memlab import MemReporter

def _make_encoder_layer(in_channels, out_channels, kernel_size=3):
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1)
    pool = nn.MaxPool2d(2, stride=2)
    norm = nn.BatchNorm2d(out_channels)
    act = nn.ReLU()

    return nn.Sequential(conv, pool, norm, act)

def hook_func(module: nn.Module, _inputs, _outputs):
    pass


class TestModule(nn.Module):
    def __init__(self):
        super(TestModule, self).__init__()
        self.encoder = _make_encoder_layer(1, 3)
        self.grad_handles = []

    def forward(self, x):
        self._register_backward_hooks()

        x.requires_grad = True
        h = self.encoder(x)

        torch.autograd.backward(h.mean(), create_graph=True)
        # do something with grad
        blah = x.grad
        self.zero_grad()

        self._remove_backward_hooks()
        return h


    def _register_backward_hooks(self):
        # Iterate through layers
        for m in self.encoder.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)):
                handle_g = m.register_full_backward_hook(hook_func)
                self.grad_handles.append(handle_g)


    def _remove_backward_hooks(self):
        for h in self.grad_handles:
            h.remove()
        self.grad_handles = []
            

model = TestModule()
reporter = MemReporter(model)

for i in range(10):
    x = torch.rand(size=(32, 1, 128, 128))

    p = model(x)

    p.mean().backward()

    if i == 0 or i == 9:
        print("")
        print(f"On the {i+1}th iteration.")
        print("")
        reporter.report()

pytorch_memlab is used to inspect the memory usage of tensor variables. The above snippet gives the following output:

On the 1th iteration.
Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cpu
Tensor0                                    (32, 1, 128, 128)     2.00M
Tensor1                                                 (1,)   512.00B
Tensor2                                      (32, 3, 64, 64)     1.50M
Tensor3                                      (32, 3, 64, 64)     1.50M
Tensor4                                    (32, 3, 128, 128)     6.00M
Tensor5                                                 (3,)   512.00B
Tensor6                                                 (3,)   512.00B
encoder.0.weight                                (3, 1, 3, 3)   512.00B
encoder.0.weight.grad                           (3, 1, 3, 3)   512.00B
encoder.0.bias                                          (3,)   512.00B
encoder.0.bias.grad                                     (3,)   512.00B
encoder.2.weight                                        (3,)   512.00B
encoder.2.weight.grad                                   (3,)   512.00B
encoder.2.bias                                          (3,)   512.00B
encoder.2.bias.grad                                     (3,)   512.00B
-------------------------------------------------------------------------------
Total Tensors: 2883663 	Used Memory: 11.01M
-------------------------------------------------------------------------------

On the 10th iteration.

Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cpu
Tensor3                                      (32, 3, 64, 64)     1.50M
Tensor4                                    (32, 3, 128, 128)     6.00M
Tensor7                                      (32, 3, 64, 64)     1.50M
Tensor8                                    (32, 3, 128, 128)     6.00M
Tensor9                                      (32, 3, 64, 64)     1.50M
Tensor10                                   (32, 3, 128, 128)     6.00M
Tensor11                                     (32, 3, 64, 64)     1.50M
Tensor12                                   (32, 3, 128, 128)     6.00M
Tensor13                                     (32, 3, 64, 64)     1.50M
Tensor14                                   (32, 3, 128, 128)     6.00M
Tensor15                                     (32, 3, 64, 64)     1.50M
Tensor16                                   (32, 3, 128, 128)     6.00M
Tensor17                                     (32, 3, 64, 64)     1.50M
Tensor18                                   (32, 3, 128, 128)     6.00M
Tensor19                                     (32, 3, 64, 64)     1.50M
Tensor20                                   (32, 3, 128, 128)     6.00M
Tensor21                                     (32, 3, 64, 64)     1.50M
Tensor22                                   (32, 3, 128, 128)     6.00M
Tensor23                                   (32, 1, 128, 128)     2.00M
Tensor24                                                (1,)   512.00B
Tensor25                                     (32, 3, 64, 64)     1.50M
Tensor26                                     (32, 3, 64, 64)     1.50M
Tensor27                                   (32, 3, 128, 128)     6.00M
Tensor5                                                 (3,)   512.00B
Tensor6                                                 (3,)   512.00B
encoder.0.weight                                (3, 1, 3, 3)   512.00B
encoder.0.weight.grad                           (3, 1, 3, 3)   512.00B
encoder.0.bias                                          (3,)   512.00B
encoder.0.bias.grad                                     (3,)   512.00B
encoder.2.weight                                        (3,)   512.00B
encoder.2.weight.grad                                   (3,)   512.00B
encoder.2.bias                                          (3,)   512.00B
encoder.2.bias.grad                                     (3,)   512.00B
-------------------------------------------------------------------------------
Total Tensors: 20578383 	Used Memory: 78.51M
-------------------------------------------------------------------------------

Somehow a lot of tensors are not reclaimed by garbage collection, and increasing the number of loop iterations eventually leads to an out-of-memory RuntimeError.
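As an extra sanity check independent of pytorch_memlab, one can also count the torch.Tensor objects tracked by Python's garbage collector after each iteration. This is only an illustrative sketch, not part of the original snippet:

import gc
import torch

def count_live_tensors():
    # Count every torch.Tensor instance currently tracked by the GC.
    n = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                n += 1
        except Exception:
            # Some objects raise on inspection; skip them.
            pass
    return n

# e.g. call print(i, count_live_tensors()) at the end of each loop iteration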

However, switching back to the old register_backward_hook method makes the problem go away.

It is interesting that switching back to register_backward_hook solves the issue. The memory leak is most likely caused by calling .backward() with create_graph=True, which creates a reference cycle. This is documented here: Automatic differentiation package - torch.autograd — PyTorch 1.8.1 documentation. We plan to add a warning for this soon…
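For completeness, one way to avoid the cycle entirely is to compute the gradient with torch.autograd.grad instead of .backward(create_graph=True): the gradient is returned directly and never stored in x.grad, so no cycle through the Tensor is created. A rough sketch of what that could look like for your forward (note this only computes the gradient w.r.t. x and, unlike .backward(), does not fill the parameter .grad fields):

    def forward(self, x):
        self._register_backward_hooks()

        x.requires_grad = True
        h = self.encoder(x)

        # autograd.grad returns the gradient of h.mean() w.r.t. x directly,
        # without populating x.grad, so no Tensor -> .grad -> ... -> Tensor
        # cycle is created.
        (x_grad,) = torch.autograd.grad(h.mean(), x, create_graph=True)
        # do something with x_grad

        self._remove_backward_hooks()
        return h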

I am not sure I understand why a reference cycle is created. Could you please elaborate on this?
Also, I called self.zero_grad() after the backward operation. If a reference cycle is created, shouldn't zero_grad() remove it?

A Tensor holds a reference to its .grad. When we set create_graph=True, we build the backward graph for .grad so that we can perform double backward on it. In most cases this .grad will be a function of the original input, which means there is a chain of backward nodes leading back to the grad_accumulator of the input. Finally, this grad_accumulator holds an owning reference to its Tensor.
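You can see that chain in a tiny standalone example (a sketch only; .next_functions and .variable are autograd internals, so the exact node names may vary between versions):

import torch

x = torch.rand(3, requires_grad=True)
y = (x ** 2).sum()
torch.autograd.backward(y, create_graph=True)

# With create_graph=True, x.grad (here 2 * x) is itself the output of a
# recorded operation, so it carries a grad_fn.
print(x.grad.grad_fn)  # e.g. <MulBackward0 object at ...>

# Walk the backward graph behind x.grad and look for the accumulator node
# that owns x itself.
stack = [x.grad.grad_fn]
while stack:
    node = stack.pop()
    if hasattr(node, "variable") and node.variable is x:
        print("reached the grad_accumulator of x -> the cycle is closed")
        break
    stack.extend(fn for fn, _ in node.next_functions if fn is not None)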

The cycle is: Tensor → .grad → .grad_fn → … → grad_accumulator → .variable → Tensor
When you call zero_grad, you do not actually break the reference cycle: it only fills the gradient with zeros, so the Tensor still holds a reference to .grad, which still holds a reference to its .grad_fn.
To break the cycle you need to set x.grad = None.
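Roughly, your forward could look like this after the gradient has been used (a sketch only; it assumes your PyTorch version has the set_to_none argument on Module.zero_grad, and it also drops the parameter gradients, since under create_graph=True they keep graph state alive as well):

    def forward(self, x):
        self._register_backward_hooks()

        x.requires_grad = True
        h = self.encoder(x)

        torch.autograd.backward(h.mean(), create_graph=True)
        blah = x.grad
        # do something with blah

        # Drop the graph-carrying gradients instead of zeroing them in place,
        # so the Tensor -> .grad -> ... -> Tensor cycle is broken.
        x.grad = None
        self.zero_grad(set_to_none=True)

        self._remove_backward_hooks()
        return h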
