Forward and backward hooks not firing on the attn_drop layer

Hi,
I am working on visualizing the attention layers of the deit_tiny_patch16_224 model. To achieve this, I registered forward and backward hooks on the attn_drop layer using register_forward_hook and register_full_backward_hook. However, when I run the model, the hooks for the attn_drop layer are not being triggered.

Interestingly, if I register the same hooks on a different layer, such as k_norm, they work as expected.

I suspect this issue might be related to the dropout probability being set to 0.0 in the attn_drop layer. From my understanding, the dropout probability shouldn’t affect the execution of hooks, but I might be overlooking something. Could the disabled dropout (p=0.0) be causing the hooks to behave unexpectedly, or is there another reason the hooks aren’t working for this layer?
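For reference, here is a minimal sketch of how I register the hooks; the blocks[0].attn.attn_drop and blocks[0].attn.k_norm module paths are my assumptions based on the timm implementation of this model:

import torch
import timm

model = timm.create_model("deit_tiny_patch16_224", pretrained=False)

# hooks on the attention dropout of the first block -- these never print anything
attn_drop = model.blocks[0].attn.attn_drop
attn_drop.register_forward_hook(lambda m, inp, out: print("attn_drop forward:", out.shape))
attn_drop.register_full_backward_hook(lambda m, gin, gout: print("attn_drop backward:", gout[0].shape))

# the same hooks on k_norm fire as expected
k_norm = model.blocks[0].attn.k_norm
k_norm.register_forward_hook(lambda m, inp, out: print("k_norm forward:", out.shape))

x = torch.randn(1, 3, 224, 224)
out = model(x)
out.mean().backward()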


Disabling dropout via p=0.0 won't change the hook logic, as shown here:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, p):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.drop = nn.Dropout(p)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.drop(x)
        return x

x = torch.randn(1, 4)

# dropout enabled (p=0.5): both hooks fire and the kept activations are scaled by 1 / (1 - p)
model = MyModel(p=0.5)
model.drop.register_forward_hook(lambda m, x, o: print("input {}, output {}".format(x, o)))
model.drop.register_full_backward_hook(lambda m, gi, go: print("ginput {}, goutput {}".format(gi, go)))

out = model(x)
# input (tensor([[-0.5069, -0.0430,  0.9238, -0.5910]],
#        grad_fn=<BackwardHookFunctionBackward>),), output tensor([[-1.0138, -0.0859,  1.8477, -1.1820]], grad_fn=<MulBackward0>)

out.mean().backward()
# ginput (tensor([[0.5000, 0.5000, 0.5000, 0.5000]]),), goutput (tensor([[0.2500, 0.2500, 0.2500, 0.2500]]),)


# dropout disabled (p=0.0): the layer acts as an identity, but both hooks still fire
model = MyModel(p=0.0)
model.drop.register_forward_hook(lambda m, x, o: print("input {}, output {}".format(x, o)))
model.drop.register_full_backward_hook(lambda m, gi, go: print("ginput {}, goutput {}".format(gi, go)))

out = model(x)
# input (tensor([[ 0.1606, -0.6628, -0.7668, -0.4123]],
#        grad_fn=<BackwardHookFunctionBackward>),), output tensor([[ 0.1606, -0.6628, -0.7668, -0.4123]],
#        grad_fn=<BackwardHookFunctionBackward>)

out.mean().backward()
# ginput (tensor([[0.2500, 0.2500, 0.2500, 0.2500]]),), goutput (tensor([[0.2500, 0.2500, 0.2500, 0.2500]]),)

However, I don’t know if your model's implementation removes or bypasses these layers to optimize the forward pass, which would explain why the hooks are never called.
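If the timm implementation you are using exposes a fused_attn flag on the attention module (newer versions route the computation through F.scaled_dot_product_attention and never call the attn_drop module), you could check and disable it along these lines; the attribute name is an assumption about your timm version:

import timm

model = timm.create_model("deit_tiny_patch16_224", pretrained=False)
attn = model.blocks[0].attn

# True would mean the fused kernel is used and the attn_drop module is bypassed
print(getattr(attn, "fused_attn", "no fused_attn attribute"))

# force the non-fused path so the attn_drop hooks can fire again
for blk in model.blocks:
    if hasattr(blk.attn, "fused_attn"):
        blk.attn.fused_attn = False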
