PyTorch Transformer doesn't work with register_forward_pre_hook

I have the test code below, and it only prints the message from the hook registered on the Linear layer.

Is there a reason why the hook doesn't fire for the _LinearWithBias module in MultiheadAttention?


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class SimpleVT(nn.Module):

    def __init__(self):
        super().__init__()
        self.enc_layers = TransformerEncoderLayer(40, 2, 20, 0.5)
        self.encoder = TransformerEncoder(self.enc_layers, 2)
        self.decoder = nn.Linear(40, 2)

    def forward(self, x):
        x = self.enc_layers(x)
        x = self.encoder(x)
        x = self.decoder(x)
        return x

model = SimpleVT().cuda()
print(model)

def forward_pre_hook_fn(module, i):
    print("forward_pre_hook_fn", type(module))

print(type(model.enc_layers.self_attn.out_proj))
model.enc_layers.self_attn.out_proj.register_forward_pre_hook(forward_pre_hook_fn)
model.enc_layers.linear1.register_forward_pre_hook(forward_pre_hook_fn)

x = torch.rand([64, 102, 40]).cuda()
y_hat = torch.rand([64, 102, 2]).cuda()

y = model(x)


There is no print from the hook registered with model.enc_layers.self_attn.out_proj.

I get a valid output:

y = model(x)
forward_pre_hook_fn <class 'torch.nn.modules.linear.Linear'>

Also, based on the screenshot it seems you are seeing the same output.

PS: you can post code snippets by wrapping them into three backticks ```, which makes debugging easier. :wink:

Hi @ptrblck, actually that's not quite right. I registered the same hook function on two modules:

model.enc_layers.self_attn.out_proj.register_forward_pre_hook(forward_pre_hook_fn)
model.enc_layers.linear1.register_forward_pre_hook(forward_pre_hook_fn)

The output we are seeing comes from the second registration (linear1). The first one (out_proj) doesn't fire, otherwise we would see a second print. I hope this makes it clear.

Thanks for the clarification. I missed this output.
The layer itself is never called. Instead, its parameters (weight and bias) are passed directly to F.multi_head_attention_forward in these lines of code, which is why the hook isn't called.
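
To illustrate the mechanism with something smaller than the attention code, here is a minimal sketch using a plain nn.Linear. The names `pre_hook` and `calls` are made up for this example. It only shows the general behaviour: hooks run when the module object itself is called, not when its parameters are fed to a functional op.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

calls = []  # records which modules triggered the hook (illustrative helper)

def pre_hook(module, inputs):
    # Forward pre-hooks are run by nn.Module.__call__, before forward().
    calls.append(type(module).__name__)

layer = nn.Linear(4, 3)
layer.register_forward_pre_hook(pre_hook)

x = torch.rand(2, 4)

# Using the layer's parameters through the functional API bypasses
# nn.Module.__call__, so no hook is triggered.
out_functional = F.linear(x, layer.weight, layer.bias)
calls_after_functional = list(calls)

# Calling the module itself goes through __call__ and triggers the hook.
out_module = layer(x)
calls_after_module = list(calls)

print(calls_after_functional)
print(calls_after_module)
```

Both calls compute the same result, but only the second one goes through the module, which mirrors how out_proj is used inside the attention forward.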

Oh, I see it in the implementation now, thanks. I wonder why it is done that way in this particular case, but I guess it was a design choice. Is there a way to systematically detect this situation (a child module exists but is never called in forward)? Or does it have to be checked case by case?

I don't know why this approach was chosen, and I don't know of another way to check whether a module's forward method was called besides what you've already done: using hooks and checking their output.
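
If it helps, the hook-based check can be applied to every submodule at once. The following is only a rough sketch: `find_uncalled_modules` is a hypothetical helper written for this thread, not a PyTorch API, and the small layer sizes are arbitrary. It registers a forward pre-hook on each entry of `named_modules()`, runs one forward pass, and lists the names whose hook never fired.

```python
import torch
import torch.nn as nn

def find_uncalled_modules(model, *inputs):
    """Hypothetical helper: names of submodules whose pre-hook never fired."""
    called = set()
    handles = []
    for name, module in model.named_modules():
        # Bind the current name as a default argument so each hook keeps its own.
        def pre_hook(mod, args, name=name):
            called.add(name)
        handles.append(module.register_forward_pre_hook(pre_hook))
    try:
        model(*inputs)
    finally:
        # Always remove the temporary hooks, even if the forward pass fails.
        for handle in handles:
            handle.remove()
    return [name for name, _ in model.named_modules() if name not in called]

# Small CPU example; the sizes are arbitrary. The layer is left in training
# mode (the default after construction).
layer = nn.TransformerEncoderLayer(d_model=8, nhead=2, dim_feedforward=16, dropout=0.0)
x = torch.rand(5, 3, 8)  # (sequence, batch, features)

uncalled = find_uncalled_modules(layer, x)
print(uncalled)
```

For the encoder layer discussed here, I would expect self_attn.out_proj to show up in the list. Keep in mind that the result only describes the code path taken for that particular input and mode (train vs. eval), so a module reported as uncalled may still be called under other conditions.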