Hi,
It seems like there is a bug with the new register_full_backward_hook method that causes a memory leak, while the old register_backward_hook does not have this problem.
The following code snippet uses register_full_backward_hook:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_memlab import MemReporter


def _make_encoder_layer(in_channels, out_channels, kernel_size=3):
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1)
    pool = nn.MaxPool2d(2, stride=2)
    norm = nn.BatchNorm2d(out_channels)
    act = nn.ReLU()
    return nn.Sequential(conv, pool, norm, act)


def hook_func(module: nn.Module, _inputs, _outputs):
    pass


class TestModule(nn.Module):
    def __init__(self):
        super(TestModule, self).__init__()
        self.encoder = _make_encoder_layer(1, 3)
        self.grad_handles = []

    def forward(self, x):
        self._register_backward_hooks()
        x.requires_grad = True
        h = self.encoder(x)
        torch.autograd.backward(h.mean(), create_graph=True)
        # do something with grad
        blah = x.grad
        self.zero_grad()
        self._remove_backward_hooks()
        return h

    def _register_backward_hooks(self):
        # Iterate through layers and hook the ones we care about
        for m in self.encoder.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.BatchNorm2d):
                handle_g = m.register_full_backward_hook(hook_func)
                self.grad_handles.append(handle_g)

    def _remove_backward_hooks(self):
        for h in self.grad_handles:
            h.remove()
        self.grad_handles = []


model = TestModule()
reporter = MemReporter(model)

for i in range(10):
    x = torch.rand(size=(32, 1, 128, 128))
    p = model(x)
    p.mean().backward()
    if i == 0 or i == 9:
        print("")
        print(f"On the {i+1}th iteration.")
        print("")
        reporter.report()
pytorch_memlab is used to inspect the memory usage of tensor variables. The above snippet gives the following output:
On the 1th iteration.
Element type Size Used MEM
-------------------------------------------------------------------------------
Storage on cpu
Tensor0 (32, 1, 128, 128) 2.00M
Tensor1 (1,) 512.00B
Tensor2 (32, 3, 64, 64) 1.50M
Tensor3 (32, 3, 64, 64) 1.50M
Tensor4 (32, 3, 128, 128) 6.00M
Tensor5 (3,) 512.00B
Tensor6 (3,) 512.00B
encoder.0.weight (3, 1, 3, 3) 512.00B
encoder.0.weight.grad (3, 1, 3, 3) 512.00B
encoder.0.bias (3,) 512.00B
encoder.0.bias.grad (3,) 512.00B
encoder.2.weight (3,) 512.00B
encoder.2.weight.grad (3,) 512.00B
encoder.2.bias (3,) 512.00B
encoder.2.bias.grad (3,) 512.00B
-------------------------------------------------------------------------------
Total Tensors: 2883663 Used Memory: 11.01M
-------------------------------------------------------------------------------
On the 10th iteration.
Element type Size Used MEM
-------------------------------------------------------------------------------
Storage on cpu
Tensor3 (32, 3, 64, 64) 1.50M
Tensor4 (32, 3, 128, 128) 6.00M
Tensor7 (32, 3, 64, 64) 1.50M
Tensor8 (32, 3, 128, 128) 6.00M
Tensor9 (32, 3, 64, 64) 1.50M
Tensor10 (32, 3, 128, 128) 6.00M
Tensor11 (32, 3, 64, 64) 1.50M
Tensor12 (32, 3, 128, 128) 6.00M
Tensor13 (32, 3, 64, 64) 1.50M
Tensor14 (32, 3, 128, 128) 6.00M
Tensor15 (32, 3, 64, 64) 1.50M
Tensor16 (32, 3, 128, 128) 6.00M
Tensor17 (32, 3, 64, 64) 1.50M
Tensor18 (32, 3, 128, 128) 6.00M
Tensor19 (32, 3, 64, 64) 1.50M
Tensor20 (32, 3, 128, 128) 6.00M
Tensor21 (32, 3, 64, 64) 1.50M
Tensor22 (32, 3, 128, 128) 6.00M
Tensor23 (32, 1, 128, 128) 2.00M
Tensor24 (1,) 512.00B
Tensor25 (32, 3, 64, 64) 1.50M
Tensor26 (32, 3, 64, 64) 1.50M
Tensor27 (32, 3, 128, 128) 6.00M
Tensor5 (3,) 512.00B
Tensor6 (3,) 512.00B
encoder.0.weight (3, 1, 3, 3) 512.00B
encoder.0.weight.grad (3, 1, 3, 3) 512.00B
encoder.0.bias (3,) 512.00B
encoder.0.bias.grad (3,) 512.00B
encoder.2.weight (3,) 512.00B
encoder.2.weight.grad (3,) 512.00B
encoder.2.bias (3,) 512.00B
encoder.2.bias.grad (3,) 512.00B
-------------------------------------------------------------------------------
Total Tensors: 20578383 Used Memory: 78.51M
-------------------------------------------------------------------------------
Somehow a lot of tensors are not recycled by the garbage collection mechanism, and increasing the number of loop iterations eventually causes an out-of-memory RuntimeError.
However, switching back to the old register_backward_hook method resolves the issue.
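For reference, the only change I make to switch back to the old API is in _register_backward_hooks; a minimal sketch of that variant, with everything else in the snippet unchanged:

    def _register_backward_hooks(self):
        # Same iteration as above, but registering the old-style hook instead
        for m in self.encoder.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.BatchNorm2d):
                handle_g = m.register_backward_hook(hook_func)
                self.grad_handles.append(handle_g)

With this version, the tensor count and memory reported by MemReporter stay roughly constant across iterations in my runs.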