I am starting to use `nn.TransformerEncoder` for some experiments and was wondering if there is a way to obtain the outputs and attention weights from intermediate layers?
For example, Table 7 in the BERT paper studies the feature-extraction capabilities of BERT using outputs from intermediate layers. The attention weights would also help with analyzing the results.
Hence, I was wondering if there is an easy way to obtain them.
I’m not sure what you mean by “obtain”, but you can definitely inspect and dump the weights for any layer in the transformer. If you set up your transformer model as `model = Transformer(args)`, you can print or access the intermediate weights through the model's named parameters.
Of course, the exact keys (`self_attn`, etc.) would change based on how you set up your model, but it’s pretty straightforward.
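As a minimal sketch of inspecting weights by key, assuming a plain `nn.TransformerEncoder` rather than the custom `Transformer(args)` above (the sizes are arbitrary): every parameter is reachable through `named_parameters()` or `state_dict()` under dotted keys that include the layer index and the submodule name.

```python
import torch.nn as nn

# Small illustrative encoder; the sizes here are arbitrary assumptions.
encoder_layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, dim_feedforward=32, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

# Every parameter has a dotted key such as
# "layers.0.self_attn.in_proj_weight" or "layers.1.linear1.weight".
for name, param in encoder.named_parameters():
    print(name, tuple(param.shape))

# The same tensors can be looked up by key in the state dict.
state = encoder.state_dict()
w_in = state["layers.1.self_attn.in_proj_weight"]  # shape (3 * d_model, d_model)
```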
I have a TransformerEncoder model that has been trained on a task similar to BERT’s masking. I would like to access, for example, the outputs of the second-to-last layer during inference. Is the idea that I should take the trained model, keep only the layers up to the last one I want, and then run inference?
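One way to do this without hooks, sketched under the assumption of an `nn.TransformerEncoder` used without masks: its layers live in the `encoder.layers` `ModuleList`, so you can feed the input through only the first `k` of them. `run_first_k_layers` is a hypothetical helper name, and the encoder's optional final `norm` is deliberately skipped.

```python
import torch
import torch.nn as nn

def run_first_k_layers(encoder: nn.TransformerEncoder, src: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: hidden states after the first k encoder layers.

    Assumes no attention or padding masks (pass them to each layer if you use them).
    The encoder's optional final `norm` is not applied.
    """
    x = src
    for layer in encoder.layers[:k]:
        x = layer(x)
    return x

encoder_layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, dim_feedforward=32, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
encoder.eval()  # disable dropout for inference

src = torch.randn(2, 10, 16)  # (batch, seq, d_model) because batch_first=True
with torch.no_grad():
    penultimate = run_first_k_layers(encoder, src, k=len(encoder.layers) - 1)
print(penultimate.shape)  # torch.Size([2, 10, 16])
```

Since the layers are reused rather than copied, this works directly on the trained model; there is no need to build a truncated copy.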
Hi @partially_observed and @EvanZ
If you’d like to extract input and/or output from some modules in your model, you can easily achieve it by using `ForwardHookManager` in torchdistill.
Here is an example notebook to demonstrate the feature.
Thanks @yoshitomo-matsubara. I did end up using a hook via `register_forward_hook`, and that is working well.
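A minimal sketch of the `register_forward_hook` approach (the layer sizes and the `make_hook` helper are illustrative assumptions): one hook per encoder layer stores that layer's output, so every intermediate hidden state, including the second-to-last one, is available after a single forward pass.

```python
import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, dim_feedforward=32, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)
encoder.eval()

hidden_states = {}  # layer index -> output tensor of that layer

def make_hook(idx):
    # Forward hooks are called as hook(module, inputs, output) after forward() runs.
    def hook(module, inputs, output):
        hidden_states[idx] = output.detach()
    return hook

handles = [layer.register_forward_hook(make_hook(i)) for i, layer in enumerate(encoder.layers)]

src = torch.randn(2, 10, 16)  # (batch, seq, d_model)
with torch.no_grad():
    final = encoder(src)

second_to_last = hidden_states[len(encoder.layers) - 2]

# Remove the hooks once they are no longer needed.
for h in handles:
    h.remove()
```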
Great to hear that you resolved it yourself.
`ForwardHookManager` in torchdistill leverages PyTorch's forward hooks, just as you did. It also manages the handles produced by the forward hooks for you, so it may be worth a look next time.
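If you stick with plain PyTorch, a small context manager can take care of the handle bookkeeping. This is a hypothetical helper, not torchdistill's API: it registers one forward hook per named module and calls `remove()` on every handle when the `with` block exits, even if an exception is raised.

```python
import contextlib
import torch
import torch.nn as nn

@contextlib.contextmanager
def capture_outputs(modules):
    """Hypothetical helper: record each named module's output inside the `with` block."""
    captured = {}
    handles = []
    for name, module in modules.items():
        # Bind `name` as a default argument so each hook keeps its own key.
        def hook(mod, inputs, output, name=name):
            captured[name] = output.detach()
        handles.append(module.register_forward_hook(hook))
    try:
        yield captured
    finally:
        # Always detach the hooks so they do not pile up between runs.
        for h in handles:
            h.remove()

encoder_layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, dim_feedforward=32, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
encoder.eval()

src = torch.randn(2, 10, 16)
named = {f"layer{i}": layer for i, layer in enumerate(encoder.layers)}
with torch.no_grad(), capture_outputs(named) as outs:
    encoder(src)
print(sorted(outs))  # ['layer0', 'layer1']
```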
Hey @EvanZ, I want to visualize the attention over a sequence to draw conclusions from it. I use the `nn.TransformerEncoder` model, and the only solution that seems to be available is using forward hooks. Could you share exactly how you did this?
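One caveat with hooks here: in recent PyTorch versions, `nn.TransformerEncoderLayer` calls its `self_attn` module with `need_weights=False`, so hooking `self_attn` does not return the attention weights (and during inference a fused fast path may skip the submodule call entirely). A sketch of one workaround, assuming the default post-norm layers (`norm_first=False`), no masks, and eval mode so dropout does not alter the weights, is to recompute the weights by calling each layer's `self_attn` on that layer's input. `encoder_attention_maps` is a hypothetical helper name.

```python
import torch
import torch.nn as nn

def encoder_attention_maps(encoder, src):
    """Hypothetical helper: per-layer attention weights, averaged over heads.

    Assumes norm_first=False (so self-attention sees the raw layer input),
    no attention/padding masks, and encoder.eval() so dropout is inactive.
    """
    maps = []
    x = src
    for layer in encoder.layers:
        # Recompute self-attention on this layer's input, asking for the weights.
        _, weights = layer.self_attn(x, x, x, need_weights=True)
        maps.append(weights)  # (batch, seq, seq): head-averaged weights
        x = layer(x)  # advance to the next layer's input
    return maps

encoder_layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, dim_feedforward=32, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
encoder.eval()

src = torch.randn(2, 10, 16)
with torch.no_grad():
    maps = encoder_attention_maps(encoder, src)
print(len(maps), maps[0].shape)  # 2 torch.Size([2, 10, 10])
```

Each row of a map is a probability distribution over the source positions, so it can be plotted directly as a heatmap. If you use padding masks, pass the same `key_padding_mask` to `self_attn` and to each layer.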
Do you know why W_Q might be None?
I initialized my Transformer-Encoder in the following way:
```python
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
transformer = nn.TransformerEncoder(encoder_layer, num_layers)
```
Other than that, the network works as it should.
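One likely explanation, assuming W_Q refers to `self_attn.q_proj_weight`: when the query, key and value dimensions all equal `d_model` (as in `nn.TransformerEncoderLayer`), `nn.MultiheadAttention` stores the three projections stacked in a single `in_proj_weight` of shape `(3 * d_model, d_model)` and registers `q_proj_weight`, `k_proj_weight` and `v_proj_weight` as `None`. The sketch below (with illustrative sizes) slices the packed matrix back into its query, key and value parts, in that order.

```python
import torch.nn as nn

d_model, nhead = 16, 4  # illustrative sizes
encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=32, batch_first=True)
transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)

attn = transformer.layers[0].self_attn
print(attn.q_proj_weight)  # None: the q/k/v projections are packed because their dims match

# in_proj_weight stacks W_Q, W_K, W_V along dim 0 -> shape (3 * d_model, d_model)
W_Q, W_K, W_V = attn.in_proj_weight.chunk(3, dim=0)
b_Q, b_K, b_V = attn.in_proj_bias.chunk(3, dim=0)
print(W_Q.shape)  # torch.Size([16, 16])
```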