I’m implementing Grad-CAM, but it seems there’s a memory leak.

In the `GradCAM` class, I register hooks like this:
```python
import gc

import torch
import torch.nn.functional as F
from torchvision import models

class GradCAM:
    def __init__(self, module, target_layer):
        self.module = module
        self.target_layer = target_layer
        self.target_output = None
        self.target_output_grad = None

        def forward_hook(_, __, output):
            # cache the target layer's activation on the forward pass
            self.target_output = output.clone()

        def backward_hook(_, __, grad_output):
            assert len(grad_output) == 1
            # cache the gradient w.r.t. the target layer's output
            self.target_output_grad = grad_output[0].clone()

        self.target_layer.register_forward_hook(forward_hook)
        self.target_layer.register_backward_hook(backward_hook)
```
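For completeness: I don’t keep the handles that the `register_*` calls return. A variant that stores them so the hooks can be detached explicitly would look roughly like this (`self.handles` and `remove_hooks` are hypothetical names, not part of my actual code):

```python
        # hypothetical variant: keep the RemovableHandles so the hooks
        # can be detached once the CAM has been computed
        self.handles = [
            self.target_layer.register_forward_hook(forward_hook),
            self.target_layer.register_backward_hook(backward_hook),
        ]

    def remove_hooks(self):
        for handle in self.handles:
            handle.remove()
```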
To get the Grad-CAM result:
```python
    def get_grad_cam(self, image, target_class=None, counter=False):
        out = self.forward_pass(image)
        # ... skip code calculating onehot
        out.backward(onehot)

        # global-average-pool the gradient to get per-channel weights
        grad = self.target_output_grad
        grad = F.adaptive_avg_pool2d(grad, 1)

        # weighted sum of the activation maps over channels, then ReLU
        feature = self.target_output * grad
        feature = F.relu(torch.sum(feature, dim=1))
        #### del self.target_output, self.target_output_grad
        return feature.squeeze().detach().cpu().numpy()
```
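(For reference, the skipped part just builds the one-hot tensor that `backward` is called with. It is roughly like the sketch below; this is illustrative rather than my exact code, and the `counter` handling is my own naming:)

```python
# illustrative sketch only, not my exact code; assumes `out` holds the
# (1, num_classes) logits returned by forward_pass
if target_class is None:
    target_class = out.argmax(dim=1).item()
onehot = torch.zeros_like(out)
# assumption: counter=True flips the sign for a counterfactual map
onehot[0, target_class] = -1.0 if counter else 1.0
```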
And this is the code that uses the `GradCAM` class:
```python
def foo():
    model = models.resnet101(pretrained=False)
    target_layer = model.layer4[-1].conv3
    cam = GradCAM(model, target_layer)
    # ... skip code loading image
    out = cam.get_grad_cam(image_t, target_class=-1, counter=True)
    out = grad_utils.build_heatmap(out, size=image_sz_cv2)
    del model, cam
    return 'dummy_string'
```
However, when I run the following code, some tensors are not released even though all references to them have been deleted:
```python
def main():
    print(foo())
    gc.collect()
    # count the tensors that are still alive after foo() has returned
    cnt = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                cnt += 1
        except:
            pass
    print('cnt', cnt)
```
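(To see exactly what survives, the same loop can also print the tensors’ shapes; a small sketch:)

```python
# sketch: print shape and device of the surviving tensors
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj):
            print(type(obj).__name__, tuple(obj.size()), obj.device)
    except Exception:
        pass
```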
The problem is gone when I add `del self.target_output, self.target_output_grad` at the end of `get_grad_cam` (the `####` line above, uncommented), but I don’t understand why this happens.
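In other words, this version does not leak:

```python
        # same ending as above, with the del uncommented
        feature = F.relu(torch.sum(feature, dim=1))
        del self.target_output, self.target_output_grad  # cnt drops as expected
        return feature.squeeze().detach().cpu().numpy()
```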
In my understanding, both `self.target_output` and `self.target_output_grad`, along with all other parts of the computation graph, should be released once `foo()` returns in `main()`, since no references to them exist any more. Is there any point I missed?