Hi,
I've come across a situation where my full backward hooks are still triggered during the backward pass even though they were removed by calling handle.remove().
import torch
import torch.nn as nn
import torch.nn.functional as F
def _make_encoder_layer(in_channels, out_channels, kernel_size=3):
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1)
    pool = nn.MaxPool2d(2, stride=2)
    norm = nn.BatchNorm2d(out_channels)
    act = nn.ReLU()
    return nn.Sequential(conv, pool, norm, act)
def print_hook(module: nn.Module, _inputs, _outputs):
    print('hook triggered on', module)
class TestModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = _make_encoder_layer(1, 3)
        self.grad_handles = []

    def forward(self, x):
        self._register_backward_hooks()
        x.requires_grad = True
        h = self.encoder(x)
        print("Calling backward when hooks are enabled.")
        torch.autograd.backward(h.mean(), create_graph=True)
        # do something with grad
        blah = x.grad
        self.zero_grad()
        self._remove_backward_hooks()
        return h

    def _register_backward_hooks(self):
        # Iterate through the encoder layers and hook the parametric ones
        for m in self.encoder.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)):
                handle_g = m.register_full_backward_hook(print_hook)
                self.grad_handles.append(handle_g)
                print(f"{m} registered.")
        print("")

    def _remove_backward_hooks(self):
        for h in self.grad_handles:
            h.remove()
        self.grad_handles = []
model = TestModule()
x = torch.rand(size=(32, 1, 128, 128))
p = model(x)
print("Calling backward after hooks are removed.")
p.mean().backward()
Running this code gives the following output:
Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) registered.
BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) registered.
Calling backward when hooks are enabled.
hook triggered on BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
hook triggered on Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Calling backward after hooks are removed.
hook triggered on BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
hook triggered on Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
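For comparison, here is a minimal check I ran separately (a sketch reusing print_hook from above; the layer and input shape are just placeholders): when the hook is removed before the forward pass that builds the graph, it never fires, so handle.remove() itself seems to work in that case.
conv = nn.Conv2d(1, 3, 3, padding=1)
handle = conv.register_full_backward_hook(print_hook)
handle.remove()  # removed before any forward pass
out = conv(torch.rand(1, 1, 8, 8))
out.mean().backward()  # nothing is printed
In my repro above, by contrast, the hooks are removed only after the forward pass (and the first backward call) have already run, which is the only difference I can see.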
Is this behaviour intended, or is this a bug?