Backward hooks still called after being removed?

Hi,

I've run into a situation where my backward hooks are still triggered during the backward pass, even though they were removed beforehand by calling handle.remove() on each handle. Here is a minimal example:

import torch
import torch.nn as nn
import torch.nn.functional as F


def _make_encoder_layer(in_channels, out_channels, kernel_size=3):
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1)
    pool = nn.MaxPool2d(2, stride=2)
    norm = nn.BatchNorm2d(out_channels)
    act = nn.ReLU()

    return nn.Sequential(conv, pool, norm, act)


def print_hook(module: nn.Module, _inputs, _outputs):
    print('hook triggered on', module)


class TestModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = _make_encoder_layer(1, 3)
        self.grad_handles = []

    def forward(self, x):
        self._register_backward_hooks()

        x.requires_grad = True  # so x.grad gets populated below
        h = self.encoder(x)

        print("Calling backward when hooks are enabled.")

        # create_graph=True also retains the graph, so backward can run again later
        torch.autograd.backward(h.mean(), create_graph=True)
        # do something with grad
        blah = x.grad
        self.zero_grad()

        self._remove_backward_hooks()
        return h

    def _register_backward_hooks(self):
        # Iterate through layers
        for m in self.encoder.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)):
                handle_g = m.register_full_backward_hook(print_hook)
                self.grad_handles.append(handle_g)
                print(f"{m} registered.")
        print("")

    def _remove_backward_hooks(self):
        for h in self.grad_handles:
            h.remove()
        self.grad_handles = []

model = TestModule()
x = torch.rand(size=(32, 1, 128, 128))

p = model(x)

print("Calling backward after hooks are removed.")
p.mean().backward()

Running this code gives the following output:

Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) registered.
BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) registered.

Calling backward when hooks are enabled.
hook triggered on BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
hook triggered on Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Calling backward after hooks are removed.
hook triggered on BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
hook triggered on Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
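
If it helps, the same thing seems to happen without any of the Module boilerplate above. This stripped-down sketch is my own minimal repro, nothing from the docs; here the hook is removed right after the forward pass, before backward is ever called, yet it still fires:

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)
handle = conv.register_full_backward_hook(
    lambda module, _grad_in, _grad_out: print('hook triggered on', module)
)

out = conv(torch.rand(1, 1, 8, 8, requires_grad=True))
handle.remove()          # removed after forward, before any backward call
out.mean().backward()    # the hook still fires here

So my guess is that whichever hooks are registered at the time of the forward pass get captured into the graph, and handle.remove() only affects future forward passes, but I couldn't find this documented anywhere.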

Is this behaviour intended, or is this a bug?
