grad_in and grad_out are not freed during a "full backward hook", causing a memory leak

Hi,

I found that when backward() is called with create_graph=True, a "full backward hook" causes a memory leak.
(A non-full backward hook doesn't have this issue.)
It seems that grad_in and grad_out are not freed, as the code and results below show (measured with pytorch_memlab).

I've also set the .grad attributes of the module parameters to None, following the warning message:

UserWarning: Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak. We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak.

Is this expected behaviour?
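For comparison, here is a minimal sketch of the alternative the UserWarning recommends (this is my own illustration, not code from the report): computing second-order-capable gradients with torch.autograd.grad instead of loss.backward(create_graph=True), so the .grad fields are never populated and no parameter/gradient reference cycle is formed.

```python
import torch
import torch.nn as nn

# Hypothetical small model, just to show the pattern.
model = nn.Linear(4, 2)
x = torch.randn(8, 4)
loss = model(x).pow(2).mean()

# Gradients are returned directly; the parameters' .grad fields stay None,
# so there is no parameter <-> gradient cycle to break afterwards.
grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)

assert all(p.grad is None for p in model.parameters())
assert len(grads) == 2  # one gradient each for weight and bias
```

This doesn't touch the hook itself, of course; the question is whether grad_in/grad_out held by the full backward hook should also be released.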

[reproducing code …]

import torch
import torch.nn as nn

from pytorch_memlab import MemReporter


class TestModel(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.nn = nn.Linear(input_size, output_size)
        self._register_backward_hooks()

    def forward(self, x):
        pred = self.nn(x)
        return pred

    def _register_backward_hooks(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.register_full_backward_hook(idle_hook)
                # m.register_backward_hook(idle_hook)

def idle_hook(m, grad_in, grad_out):
    # A no-op hook: merely registering it is enough to reproduce the leak.
    return

batch_size = 50
in_features = 200
out_features = 10
model = TestModel(in_features, out_features)

for i in range(5):
    x = torch.randint(low=0, high=2, size=(batch_size, in_features)).float()
    y = torch.randn(size=(batch_size, out_features)).float()

    pred = model(x)
    loss = torch.sum((y - pred) ** 2, -1).mean()
    loss.backward(retain_graph=True, create_graph=True)

    # Break the parameter <-> gradient reference cycle, per the UserWarning.
    for param in model.parameters():
        param.grad = None

    print(f'Iteration #{i}')
    reporter = MemReporter(model)
    reporter.report()

[result…]

Iteration #0
Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cpu
nn.weight                                          (10, 200)     8.00K
nn.bias                                                (10,)   512.00B
Tensor0                                            (50, 200)    39.50K
Tensor1                                             (50, 10)     2.00K
Tensor2                                             (50, 10)     2.00K
Tensor3                                             (50, 10)     0.00B
Tensor4                                             (50, 10)     2.00K
Tensor5                                                 (1,)   512.00B
Tensor6                                             (50, 10)     2.00K
-------------------------------------------------------------------------------
Total Tensors: 14511 	Used Memory: 56.50K
-------------------------------------------------------------------------------


Iteration #4
Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cpu
nn.weight                                          (10, 200)     8.00K
nn.bias                                                (10,)   512.00B
Tensor0                                            (50, 200)    39.50K
Tensor1                                             (50, 10)     2.00K
Tensor2                                             (50, 10)     2.00K
Tensor3                                            (50, 200)    39.50K
Tensor4                                             (50, 10)     2.00K
Tensor5                                             (50, 10)     2.00K
Tensor6                                            (50, 200)    39.50K
Tensor7                                             (50, 10)     2.00K
Tensor8                                             (50, 10)     2.00K
Tensor9                                            (50, 200)    39.50K
Tensor10                                            (50, 10)     2.00K
Tensor11                                            (50, 10)     2.00K
Tensor12                                           (50, 200)    39.50K
Tensor13                                            (50, 10)     2.00K
Tensor14                                            (50, 10)     2.00K
Tensor15                                            (50, 10)     0.00B
Tensor16                                            (50, 10)     2.00K
Tensor17                                                (1,)   512.00B
Tensor18                                            (50, 10)     2.00K
-------------------------------------------------------------------------------
Total Tensors: 58511 	Used Memory: 230.50K
-------------------------------------------------------------------------------
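As a stopgap while this is unresolved, one option I can think of (my own workaround sketch, not an official fix) is to detach the hook via the RemovableHandle that register_full_backward_hook returns, so it no longer holds references between create_graph iterations:

```python
import torch
import torch.nn as nn

calls = []

def counting_hook(module, grad_in, grad_out):
    # Record each invocation so removal can be verified below.
    calls.append(1)

lin = nn.Linear(3, 2)
handle = lin.register_full_backward_hook(counting_hook)

lin(torch.randn(4, 3)).sum().backward()
assert len(calls) == 1  # hook fired once

# Detach the hook; subsequent backward passes no longer invoke it.
handle.remove()
lin(torch.randn(4, 3)).sum().backward()
assert len(calls) == 1  # still one call
```

Re-registering the hook every iteration is awkward, though, so I'd still like to know whether holding on to grad_in/grad_out here is intended.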