Using nn.MultiheadAttention to get a heatmap

Hello, I have a question about nn.MultiheadAttention (MA for short). It is not about the explanation in the docs, but about how to use this module. I want to plot a heatmap (CAM) for my transformer-based neural network. To do this, I need the intermediate results inside MA, especially the query-key dot products (the similarity matrix). How can I get them? If that isn't possible, I would have to estimate the attention of the self-attention layers by computing dot products of their outputs, but that estimate may introduce errors. Do you have any idea how to get these intermediate results?
I wanted to use register_forward_hook, but the printed module structure really confuses me because it doesn't show the sub-layer I need:

>>> print(self_attn)
MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
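The print only lists out_proj because the query/key/value projections are not submodules. When query, key and value share the embedding size, they are stored together in a single parameter, in_proj_weight, so there is no layer inside MA to hook. A hook on the MultiheadAttention module itself can still capture its output tuple, whose second element is the attention weights. This is a minimal sketch. It assumes 12 heads and batch_first=True (only the 768 embedding size comes from the print above), and it requires PyTorch 1.11 or newer for average_attn_weights. One caveat, as far as I know: nn.TransformerEncoderLayer calls its self_attn with need_weights=False, so a hook there receives None for the weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed configuration: 768 matches the print above, 12 heads is an assumption.
embed_dim, num_heads = 768, 12
self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# Q/K/V projections live in one packed parameter, not in child modules,
# which is why print(self_attn) only shows out_proj.
print([name for name, _ in self_attn.named_children()])  # ['out_proj']
print(self_attn.in_proj_weight.shape)                    # torch.Size([2304, 768])

captured = {}

def save_attn_weights(module, inputs, output):
    # MultiheadAttention returns (attn_output, attn_output_weights);
    # the weights are None if the caller passed need_weights=False.
    captured["weights"] = output[1]

handle = self_attn.register_forward_hook(save_attn_weights)

x = torch.randn(2, 5, embed_dim)  # (batch, seq_len, embed_dim)
# average_attn_weights=False keeps one attention map per head.
self_attn(x, x, x, need_weights=True, average_attn_weights=False)
handle.remove()

print(captured["weights"].shape)  # torch.Size([2, 12, 5, 5])
```

Each row of a captured map is a softmax distribution over the keys, so it can be plotted directly as a heatmap.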

So can you help me? Thank you very much!

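I think you can use the attention weights returned by the MultiheadAttention forward call (the second element of its output when need_weights=True). If you need the raw query-key products instead, you can compute the projections yourself from the module's weights, and then calculate the attention matrix by yourself or with the ScaledDotProduct from the torchtext library.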
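The weights returned by forward() are already softmax-normalized. If you want the raw scaled query-key dot products (the pre-softmax similarity matrix), you can rebuild them from in_proj_weight and in_proj_bias. This is a sketch using plain tensor ops rather than torchtext's ScaledDotProduct. It assumes the default constructor options (bias=True, same size for query/key/value so the projection is packed, no masks, no bias_k or add_zero_attn), and the head count and batch_first=True are my own choices for illustration. It checks the softmax of the manual scores against the weights returned by forward().

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

embed_dim, num_heads = 768, 12          # assumed configuration
head_dim = embed_dim // num_heads
# dropout defaults to 0.0, so the forward pass is deterministic.
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

x = torch.randn(2, 5, embed_dim)        # (batch, seq_len, embed_dim)

with torch.no_grad():
    # Reference: per-head attention weights returned by forward().
    _, ref_weights = mha(x, x, x, need_weights=True, average_attn_weights=False)

    # The packed projection is ordered as [query; key; value].
    w_q, w_k, _ = mha.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, _ = mha.in_proj_bias.chunk(3, dim=0)
    q = F.linear(x, w_q, b_q)
    k = F.linear(x, w_k, b_k)

    def split_heads(t):
        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        b, s, _ = t.shape
        return t.view(b, s, num_heads, head_dim).transpose(1, 2)

    q, k = split_heads(q), split_heads(k)

    # Raw scaled dot products: the similarity matrix before softmax.
    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
    weights = scores.softmax(dim=-1)

print(scores.shape)  # torch.Size([2, 12, 5, 5])
print(torch.allclose(weights, ref_weights, atol=1e-5))
```

With masks or dropout enabled, the returned weights would also reflect those, so the manual check above only holds for the unmasked, no-dropout case.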

Hopefully it helps.

Sorry, I had overlooked the nn.MultiheadAttention.forward() API. Thank you very much, I got the solution!