Hello,
in order to implement channel pruning, I register the output of activations in my network as buffers and perform a register_hook on them. Doing so produces memory leaks after deleting the model unless the registered buffers are freed (self.buffer = None) at the end of the hook function. Freeing the buffers at the end of the hook works in my use-case so the problem is kinda solved but I don’t understand why it helps.
I use torch version 1.0.1.post2.
Here is a dummy standalone code sample that reproduces the problem (requires to download CIFAR-10). I delete and create the model again inside of a loop and print the memory usage at the end of the loop. The memory usage increases after each iteration unless the hook function clears the buffer (commented line in _fisher method).
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
self.Conv_1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
self.bn_1 = nn.BatchNorm2d(32)
self.relu_1 = nn.ReLU(inplace=True)
self.register_buffer("Conv_1_run_fish", None)
self.Conv_2 = nn.Conv2d(32, 320, kernel_size=3, stride=2, padding=1, bias=False)
self.bn_2 = nn.BatchNorm2d(320)
self.relu_2 = nn.ReLU(inplace=True)
self.register_buffer("Conv_2_run_fish", None)
self.avg_pool = nn.AvgPool2d(8)
self.linear = nn.Linear(320, 10)
def _fisher(self, grad_output, act_name, running_fisher_name):
# setattr(self, act_name, None) # comment to get rid of memory leak
pass
def forward(self, x):
out = self.relu_1(self.bn_1(self.Conv_1(x)))
out.register_hook(lambda x: self._fisher(x, "Conv_1_act", "Conv_1_run_fish"))
self.Conv_1_act = out
out = self.relu_2(self.bn_2(self.Conv_2(out)))
out.register_hook(lambda x: self._fisher(x, "Conv_2_act", "Conv_2_run_fish"))
self.Conv_2_act = out
out = self.avg_pool(out)
out = out.view(-1, 320)
return self.linear(out)
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = torchvision.datasets.CIFAR10(root='~/Documents/CIFAR-10', train=True, download=True,
transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=0, pin_memory=False)
for i in range(30):
model = Network()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD([v for v in model.parameters() if v.requires_grad],
lr=8e-4, momentum=0.9, weight_decay=0.0005)
model.train() # switch to train mode
for batch_in, batch_target in dataloader:
batch_in, batch_target = batch_in.to(device), batch_target.to(device)
# compute output
output = model(batch_in)
loss = criterion(output, batch_target)
# compute gradient and do SGD step
optimizer.zero_grad()
loss.backward()
optimizer.step()
del optimizer
del model
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated() / 10 ** 9)