Memory leak on register_hook associated with buffer

Hello,

In order to implement channel pruning, I register the activation outputs of my network as buffers and call register_hook on them. Doing so leaks memory after the model is deleted unless the registered buffers are freed (self.buffer = None) at the end of the hook function. Freeing the buffers at the end of the hook works for my use case, so the problem is more or less solved, but I don’t understand why it helps.

I use torch version 1.0.1.post2.

Here is a dummy standalone code sample that reproduces the problem (it requires downloading CIFAR-10). I delete and recreate the model inside a loop and print the memory usage at the end of each iteration. The memory usage increases after each iteration unless the hook function clears the buffer (the commented-out line in the _fisher method).

import torch
import torch.nn as nn

import torchvision.transforms as transforms
import torchvision
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader


class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.Conv_1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn_1 = nn.BatchNorm2d(32)
        self.relu_1 = nn.ReLU(inplace=True)
        self.register_buffer("Conv_1_run_fish", None)

        self.Conv_2 = nn.Conv2d(32, 320, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn_2 = nn.BatchNorm2d(320)
        self.relu_2 = nn.ReLU(inplace=True)
        self.register_buffer("Conv_2_run_fish", None)

        self.avg_pool = nn.AvgPool2d(8)
        self.linear = nn.Linear(320, 10)

    def _fisher(self, grad_output, act_name, running_fisher_name):
        # setattr(self, act_name, None)  # uncomment to get rid of the memory leak
        pass

    def forward(self, x):
        out = self.relu_1(self.bn_1(self.Conv_1(x)))
        out.register_hook(lambda x: self._fisher(x, "Conv_1_act", "Conv_1_run_fish"))
        self.Conv_1_act = out

        out = self.relu_2(self.bn_2(self.Conv_2(out)))
        out.register_hook(lambda x: self._fisher(x, "Conv_2_act", "Conv_2_run_fish"))
        self.Conv_2_act = out

        out = self.avg_pool(out)
        out = out.view(-1, 320)
        return self.linear(out)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = torchvision.datasets.CIFAR10(root='~/Documents/CIFAR-10', train=True, download=True,
                                           transform=transforms.ToTensor())
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=0, pin_memory=False)

    for i in range(30):
        model = Network()
        model.to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD([v for v in model.parameters() if v.requires_grad],
                                    lr=8e-4, momentum=0.9, weight_decay=0.0005)

        model.train()  # switch to train mode

        for batch_in, batch_target in dataloader:
            batch_in, batch_target = batch_in.to(device), batch_target.to(device)

            # compute output
            output = model(batch_in)

            loss = criterion(output, batch_target)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        del optimizer
        del model

        torch.cuda.empty_cache()
        print(torch.cuda.memory_allocated() / 10 ** 9)

Hi,

Does adding gc.collect() twice right after del model help? It looks like you’re creating a circular reference between Python objects.
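The suspected cycle can be reproduced without PyTorch. In this sketch, `FakeActivation` and `Net` are hypothetical stand-ins (not PyTorch classes) that mirror the structure of the code above: the module holds its activation, the activation holds its hook, and the hook lambda closes over `self`. In plain Python, `del` alone leaves the object alive, and `gc.collect()` reclaims it.

```python
import gc
import weakref


class FakeActivation:
    """Hypothetical stand-in for a non-leaf tensor: it only stores its hooks."""

    def __init__(self):
        self.hooks = []

    def register_hook(self, fn):
        self.hooks.append(fn)


class Net:
    def _fisher(self, grad):
        pass

    def forward(self):
        out = FakeActivation()
        # The lambda closes over `self`: net -> act -> hooks -> lambda -> net
        out.register_hook(lambda g: self._fisher(g))
        self.act = out


gc.disable()  # keep automatic collection from running in the middle of the demo
net = Net()
net.forward()
probe = weakref.ref(net)
del net
alive_after_del = probe() is not None   # the cycle keeps the object alive
gc.collect()
freed_after_collect = probe() is None   # Python's collector breaks a pure-Python cycle
gc.enable()
print(alive_after_del, freed_after_collect)  # True True
```

As the follow-up below reports, gc.collect() did not free the real model. A plausible explanation, which is an assumption not confirmed in this thread, is that in PyTorch part of the cycle (the hook attached to the autograd graph) is held by C++ objects that Python's cycle collector cannot traverse, so the cycle has to be broken explicitly.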


Sorry for the late reply.

I tried adding it twice, but it did not solve the problem (I use Python 3.6.6; maybe that’s related). Anyway, you were right: I was creating a circular reference (between the buffers and the model object), so storing a weak reference instead, self.Conv_1_act = weakref.ref(out), solves the problem. I think it would be more natural to put the weak reference in the lambda, though. Do you know if that’s possible?
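A minimal sketch of the weak-reference fix, again using hypothetical stand-in classes rather than real tensors: the module stores `weakref.ref(out)` instead of `out`, so it no longer owns the activation and the chain module -> activation -> hook -> module is never closed. Plain reference counting then frees the module on `del`. Code that needs the activation has to call the reference and handle `None`.

```python
import weakref


class FakeActivation:
    """Hypothetical stand-in for a non-leaf tensor: it only stores its hooks."""

    def __init__(self):
        self.hooks = []

    def register_hook(self, fn):
        self.hooks.append(fn)


class Net:
    def __init__(self):
        self.hits = 0

    def _fisher(self, grad, act_name):
        act = getattr(self, act_name)()  # dereference the weakref; None if freed
        if act is not None:
            self.hits += 1  # a real hook would combine `act` and `grad` here

    def forward(self):
        out = FakeActivation()
        out.register_hook(lambda g: self._fisher(g, "act"))
        # Store only a weak reference: the module no longer owns the activation,
        # so module -> activation -> hook -> module is not a closed cycle.
        self.act = weakref.ref(out)
        return out


net = Net()
out = net.forward()
out.hooks[0](None)   # simulate backward calling the hook
hits = net.hits      # 1: the activation was still reachable
probe = weakref.ref(net)
del net, out         # plain refcounting frees both; no gc.collect() needed
freed = probe() is None
print(hits, freed)   # 1 True
```

In real PyTorch code, the weak reference only stays valid while something else keeps the tensor object alive, so checking for `None` inside the hook seems prudent.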

Thank you very much for your help!

Hi,

I’m not sure you can bind arguments to a function or lambda in a weak way…
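One way to put the weak link in the hook itself (a sketch, not tested against torch 1.0.1 in this thread) is to capture `weakref.ref(self)` instead of `self`, and to use a small nested function rather than a lambda so the case where the module is already gone can be handled. The classes are again hypothetical stand-ins; with a real tensor, the same function would be passed to `out.register_hook`.

```python
import weakref


class FakeActivation:
    """Hypothetical stand-in for a non-leaf tensor: it only stores its hooks."""

    def __init__(self):
        self.hooks = []

    def register_hook(self, fn):
        self.hooks.append(fn)


class Net:
    def __init__(self):
        self.calls = 0

    def _fisher(self, grad, act_name):
        self.calls += 1  # a real hook would update the running Fisher estimate

    def forward(self):
        out = FakeActivation()
        self_ref = weakref.ref(self)  # the hook captures this, not `self`

        def hook(grad):
            net = self_ref()
            if net is not None:  # the module may already have been deleted
                net._fisher(grad, "act")

        out.register_hook(hook)
        self.act = out  # a strong reference is fine: nothing leads back to `self`
        return out


net = Net()
out = net.forward()
out.hooks[0](None)     # simulate backward calling the hook
calls = net.calls      # 1
probe = weakref.ref(net)
del net                # freed immediately by refcounting
freed = probe() is None
print(calls, freed)    # 1 True
```

The standard library's `weakref.WeakMethod(self._fisher)` is an alternative that weakly binds the bound method directly: calling it returns the method, or `None` once the module has been freed.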
