Incorrect hook being used in register_hook implementation

Hi, I’ve been trying to mask the gradients of some custom layers' weights in my network. I implement it as follows (before the start of each epoch/training step):

# Register a gradient-masking hook on the weight of every direct Conv2d/Linear
# child of `net`; the returned RemovableHandles are kept in cur_hooks by child name.
cur_hooks = {}    
for n,p in net.named_children():
    if isinstance(p, Conv2d) or isinstance(p, Linear):
            # Mask is 1.0 where round(some_func(novel_para)) <= 1e-6, else 0.0.
            # some_func / novel_para are members of the custom layer (not shown here).
            gradient_mask = (p.some_func(p.novel_para).round().float()<=1e-6).data
            gradient_mask = gradient_mask.float()
            # NOTE(review): this line causes the RuntimeError in the traceback below.
            # The lambda looks up `gradient_mask` when autograd calls it, not when it
            # is created, so after the loop every hook sees the LAST layer's mask and
            # the multiply hits a size mismatch. Bind the mask per layer instead:
            #     lambda grad, mask=gradient_mask: grad * mask
            # Returning a new tensor also avoids modifying `grad` in place, which the
            # Tensor.register_hook docs say a hook should not do.
            cur_hooks[n] = p.weight.register_hook(lambda grad: grad.mul_(gradient_mask))

If I print the hooks via

        # Show each layer name next to its RemovableHandle.
        for entry in cur_hooks.items():
            print(*entry)

It prints:

conv1 <torch.utils.hooks.RemovableHandle object at 0x7f554c92a1c0>
conv1a <torch.utils.hooks.RemovableHandle object at 0x7f554c92afa0>
conv1b <torch.utils.hooks.RemovableHandle object at 0x7f554c92a370>
conv1c <torch.utils.hooks.RemovableHandle object at 0x7f554c92a9a0>
conv1d <torch.utils.hooks.RemovableHandle object at 0x7f554c92aee0>
conv2a <torch.utils.hooks.RemovableHandle object at 0x7f554e05bd90>
conv2b <torch.utils.hooks.RemovableHandle object at 0x7f554e05bb80>
shortcut_conv2 <torch.utils.hooks.RemovableHandle object at 0x7f554e05be20>
conv2c <torch.utils.hooks.RemovableHandle object at 0x7f554e05bc70>
conv2d <torch.utils.hooks.RemovableHandle object at 0x7f554e05bd30>
conv3a <torch.utils.hooks.RemovableHandle object at 0x7f554c92f250>
conv3b <torch.utils.hooks.RemovableHandle object at 0x7f554c92f100>
shortcut_conv3 <torch.utils.hooks.RemovableHandle object at 0x7f554c92f9d0>
conv3c <torch.utils.hooks.RemovableHandle object at 0x7f554c92f430>
conv3d <torch.utils.hooks.RemovableHandle object at 0x7f554c92f880>
conv4a <torch.utils.hooks.RemovableHandle object at 0x7f554c92f2e0>
conv4b <torch.utils.hooks.RemovableHandle object at 0x7f554df74d90>
shortcut_conv4 <torch.utils.hooks.RemovableHandle object at 0x7f554df74f70>
conv4c <torch.utils.hooks.RemovableHandle object at 0x7f554df74cd0>
conv4d <torch.utils.hooks.RemovableHandle object at 0x7f554df74430>
linear <torch.utils.hooks.RemovableHandle object at 0x7f554df74f10>

However, at the very first backward() call, it runs into error-

Traceback (most recent call last):
  File "/mnt/", line 144, in <module>
  File "/home/user/base/lib/python3.9/site-packages/torch/", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/user/base/lib/python3.9/site-packages/torch/autograd/", line 130, in backward
  File "/mnt/", line 107, in <lambda>
    cur_hooks[n] = p.weight.register_hook(lambda grad: grad.mul_(gradient_mask))
RuntimeError: The size of tensor a (3) must match the size of tensor b (512) at non-singleton dimension 3

Apparently (if I’m not wrong), it's using the hook of the Linear layer (the last hook) for the first convolution layer (conv1), and I can’t figure out what I did wrong. Any suggestions?

Are you getting the same shape mismatch error if you are using torch.ones_like(grad) instead of gradient_mask?
If not, then I assume the gradient_mask creation is creating the wrong shape.
If yes, could you post a minimal, executable code snippet, which would reproduce this error?

Here is a dummy script to recreate the error.

import torch
import torch.nn as nn
import torchvision.datasets as dset
import torchvision.transforms as transforms

from torchvision.transforms import Compose, ToTensor, RandomAffine, RandomApply, Resize
import torch.nn.functional as F

import numpy as np

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class LeNet(nn.Module):
    """LeNet-style CNN mapping 1-channel images to 10 logits.

    Two conv(3x3) -> ReLU -> max-pool(2) stages feed three fully connected
    layers. ``fc1`` takes 16 * 5 * 5 features. For a 28x28 input (e.g.
    FashionMNIST), the spatial size goes 28 -> 26 (conv1) -> 13 (pool)
    -> 11 (conv2) -> 5 (pool).
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 channels x 5x5 feature map
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return (batch, 10) logits for a (batch, 1, H, W) input batch."""
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten all but the batch dimension. This needs no float division
        # to derive the feature count, and it also works for an empty batch,
        # where nelement / batch_size would divide by zero.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

net = LeNet().to(device=device)

## train, test eval:
loss_metric = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), 0.01)

# FashionMNIST yields PIL images, so Normalize must follow ToTensor inside a
# Compose that is passed as the ``transform`` keyword argument.
# NOTE(review): 0.1307 / 0.3081 are the usual MNIST mean/std; FashionMNIST's
# are roughly 0.286 / 0.353 -- confirm which statistics are intended.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
dataset = dset.FashionMNIST(root='/mnt/sda1/Data/', download=True, train=True,
                            transform=transform)
valset = dset.FashionMNIST(root='/mnt/sda1/Data/', download=True, train=False,
                           transform=transform)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                           shuffle=True, num_workers=4)

# testloader = torch.utils.data.DataLoader(valset, batch_size=128,
#                                          shuffle=False, num_workers=4)
def freeze_weights(model, hooks = []):
    """Register gradient-masking (plus debug-print) hooks on model's direct
    Linear/Conv2d children and return the accumulated RemovableHandles.

    For each matched child, a float mask is built that is 1.0 where the weight
    is > 0.05 and 0.0 elsewhere. Three tensor hooks are then registered on the
    weight, in this order: a "Pre hook" print, an in-place multiply of the
    gradient by the mask, and a "Post hook" print.

    Args:
        model: module whose direct children (named_children, not recursive)
            are scanned.
        hooks: handles to start from. The caller's list is not mutated, because
            it is rebound with ``+`` rather than appended to.

    Returns:
        ``hooks`` plus three handles per matched child, ordered
        [mask hook, pre-print hook, post-print hook].

    Raises:
        Exception: if the resulting handle list is empty.

    NOTE(review): this reproduces the bug discussed in this thread. All three
    lambdas read ``gradient_mask`` when autograd calls them, not when they are
    registered, so every hook uses the mask of the last matched child and the
    multiply fails with a size mismatch on the other layers. Fix by binding
    per layer, e.g. ``lambda grad, mask=gradient_mask: grad * mask``. Also,
    ``hooks = []`` is a mutable default (harmless here only because it is
    never mutated). Finally, the Tensor.register_hook docs say a hook should
    return a new gradient rather than modify ``grad`` in place.
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
            # 1.0 where weight > 0.05, else 0.0; .data keeps it out of autograd.
            gradient_mask = (module.weight>.05).data
            gradient_mask = gradient_mask.float()
            # Debug hook: fraction of non-zero gradient entries before masking.
            h_ = module.weight.register_hook(lambda grad:print\
            ("Pre hook shape {}: non zero % {}, gradient mask shape {}".format\
            (grad.shape, torch.count_nonzero(grad)/torch.numel(grad), gradient_mask.shape)))
            # Masking hook (see NOTE(review) above about late binding).
            h = module.weight.register_hook(lambda grad: grad.mul_(gradient_mask))
            # Debug hook: runs after the mask hook, so it sees the masked gradient.
            h__ = module.weight.register_hook(lambda grad:print\
            ("Post hook shape {}: non zero % {}".format\
            (grad.shape, torch.count_nonzero(grad)/torch.numel(grad))))
            hooks = hooks + [h] + [h_] + [h__]
            print("{} gradient mask shape {} when created".format(name, gradient_mask.shape))
    if len(hooks)==0:
        raise Exception("hook list is empty!") # should be 15
    return hooks
# Epoch 0 trains without masks. From epoch 1 on, fresh gradient-mask hooks are
# registered at the start of each epoch (masks are rebuilt from the current
# weights) and removed at its end.
# NOTE(review): zeroing a gradient does not fully freeze a weight under Adam.
# Its running moment estimates from earlier steps still produce a non-zero
# update. If the masked weights must stay fixed, use plain SGD without
# momentum/weight decay, or restore the masked weights after optimizer.step().
curr_hooks = []
for epoch in range(10):
    if epoch > 0:
        curr_hooks = freeze_weights(net)
        print("{} hooks created".format(len(curr_hooks)))
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        outputs = net(x)

        loss = loss_metric(outputs, y)
        loss.backward()  # the tensor hooks on the weights fire here
        optimizer.step()
    for h in curr_hooks:
        h.remove()
    curr_hooks = []

This is on Python 3.9.1 and torch version 1.7.1+cu110.

From what I understood, the gradient mask of the last layer/hook is being used in all of the hooks. I think I’d have to define a dict() for the gradient masks to keep track of them?

I think the solution is

def freeze_weights(model, hooks=None, gradient_mask_weights=None, threshold=0.05):
    """Mask the weight gradients of model's direct Linear/Conv2d children.

    For every such child, a mask (1.0 where weight > threshold, else 0.0, in the
    weight's dtype) is stored in ``gradient_mask_weights`` under the child's
    name. A hook is then registered that multiplies the weight's gradient by
    that mask.

    Each hook binds the layer name through a default argument (``n=name``). A
    plain closure over the loop variable would be resolved only when autograd
    calls the hook, so every hook would use the last layer's mask. Because the
    hook looks its mask up in the dict at call time, a mask can be replaced
    later without re-registering the hook.

    Args:
        model: module whose direct children are scanned (not recursive).
        hooks: optional existing handles. A new list is returned and the
            argument is not mutated.
        gradient_mask_weights: optional dict to fill with the masks. It is
            updated in place and returned; a new dict is used when omitted.
        threshold: weights strictly above this value keep their gradient.

    Returns:
        (hooks, gradient_mask_weights): the handles (call ``.remove()`` on
        each to undo) and the name -> mask dict the hooks read from.
    """
    hooks = [] if hooks is None else list(hooks)
    masks = {} if gradient_mask_weights is None else gradient_mask_weights
    for name, module in model.named_children():
        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            continue
        weight = module.weight
        masks[name] = (weight.detach() > threshold).to(weight.dtype)
        # Return a new tensor instead of grad.mul_(...): tensor hooks should not
        # modify their argument in place.
        hooks.append(weight.register_hook(lambda grad, n=name: grad * masks[n]))
    return hooks, masks

Thanks for the code snippet.
The approach using a dict might work, alternatively, you could also use:

# The default argument captures the current gradient_mask when the lambda is
# created. Returning a new tensor (rather than grad.mul_) follows the rule that
# a tensor hook must not modify its argument in place.
h = module.weight.register_hook(lambda grad, gradient_mask=gradient_mask: grad * gradient_mask)

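Your current code runs into a known Python pitfall with lambdas (late-binding closures): a name used inside a function body is looked up when the function is executed, not when it is defined. By the time autograd calls the hooks, the loop has finished, so every lambda sees the `gradient_mask` from the last iteration. Binding it as a default argument, as above, captures the value at definition time instead.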

1 Like

Is there any difference between register_hook and register_backward_hook? The documentation for both says they register a backward hook.
What's the difference?

register_hook is used on a Tensor, while register_backward_hook is used on nn.Module, is deprecated, and should be replaced with register_full_backward_hook.

1 Like

I want to use the activations of an intermediate layer in the T5 decoder to compute a separate loss. Is there a way to add a custom layer on top of an intermediate layer of a pretrained model, and use the outputs of both the new head and the pretrained model's head?

I tried getting the activations of an intermediate layer of a pretrained model using hooks, but I'm running into errors (shared below).

More specifically, I am unable to grab the activation values. For example, in the following code snippet, summarizer_decoder_activation is empty.

model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
summarizer_decoder_activation = {}
def get_activation(name):
  """Return a forward hook that stores a module's output under ``name``."""
  def hook(model, input, output):
      # NOTE(review): .detach() removes the stored tensor from the autograd
      # graph, so a loss computed from it cannot back-propagate into the model.
      summarizer_decoder_activation[name] = output.detach()
  return hook

for name, layer in model._modules["decoder"]._modules["block"][-1]._modules["layer"][-1]._modules["DenseReluDense"].named_children():
# for name, layer in model.named_children():
  if name == "wo":
      # NOTE(review): registration assumed -- the loop body was missing. Forward
      # hooks run during model(...), so the dict stays empty until a forward pass.
      layer.register_forward_hook(get_activation(name))

You are detaching the intermediate activation, so it’ll be a constant in the loss calculation.
Remove it and it should work:

activation = {}
def get_activation(name):
    """Return a forward hook that stores a module's output under ``name``."""
    def hook(model, input, output):
        # No .detach(): the stored tensor stays in the autograd graph, so a
        # loss term computed from it back-propagates into the model.
        activation[name] = output
    return hook

model = models.resnet18()  # presumably torchvision.models
# layer1 of resnet18 outputs 64 channels at 56x56 for a 224x224 input.
model.layer1.register_forward_hook(get_activation("aux_out"))

x = torch.randn(2, 3, 224, 224)
out = model(x)
print(out.shape)
# torch.Size([2, 1000])
print(activation["aux_out"].shape)
# torch.Size([2, 64, 56, 56])

loss = out.mean() + activation["aux_out"].mean()
loss.backward()