Interpreting Attention matrix weights in MultiheadAttention

Hi Pytorch Universe,
I have sequences of human poses captured from multiple angles and a ground-truth pose class. My data tensor has the shape [1, 97, 1, 128, 128], where the dims are [batch_size, timepoint, channel, height, width].
The class is associated with the whole sequence, not with each timepoint, so the task can be framed as weakly-supervised multiple-instance learning.

I built a transformer encoder with a CNN backbone, without a class token or positional encoding. The input tensor is first embedded by the CNN backbone into [1, 97, 128] (97 tokens with a 128-dimensional embedding each). It is then rearranged to [97, 1, 128] and passed to a transformer encoder built on PyTorch's nn.MultiheadAttention(128, 8, dropout).
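Roughly, the pipeline looks like this (a minimal sketch only; the backbone layers, sizes, and names below are simplified placeholders, not my exact model):

```python
import torch
import torch.nn as nn

class PoseSequenceEncoder(nn.Module):
    """Sketch: per-frame CNN backbone + multi-head self-attention over timepoints."""
    def __init__(self, embed_dim=128, num_heads=8, dropout=0.1):
        super().__init__()
        # Placeholder backbone: maps one [1, 128, 128] frame to a 128-d embedding.
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),   # -> [16, 64, 64]
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # -> [32, 32, 32]
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),                                 # -> [32, 1, 1]
            nn.Flatten(),                                            # -> [32]
            nn.Linear(32, embed_dim),                                # -> [128]
        )
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)

    def forward(self, x):                        # x: [B, T, C, H, W] = [1, 97, 1, 128, 128]
        b, t, c, h, w = x.shape
        frames = x.reshape(b * t, c, h, w)       # fold time into the batch dimension
        emb = self.backbone(frames)              # [B*T, 128]
        emb = emb.reshape(b, t, -1)              # [1, 97, 128]
        emb = emb.permute(1, 0, 2)               # [97, 1, 128] = (seq_len, batch, embed_dim)
        out, attn_weights = self.attn(emb, emb, emb)  # self-attention: Q = K = V = emb
        return out, attn_weights

model = PoseSequenceEncoder()
out, attn = model(torch.randn(1, 97, 1, 128, 128))
print(out.shape, attn.shape)   # [97, 1, 128] and [1, 97, 97] (head-averaged by default)
```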
What I am trying to understand is this: I extract the attention weights with shape [1, 8, 97, 97], where 8 is the number of heads and 97 is the number of tokens (timepoints). How can I interpret this attention matrix? Is it possible to create a heatmap to overlay on the timepoints of the original [1, 97, 1, 128, 128] tensor?
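For reference, this is roughly how I pull the per-head weights out (a sketch with random data standing in for my embeddings; average_attn_weights=False is only available in reasonably recent PyTorch versions):

```python
import torch
import torch.nn as nn

embed_dim, num_heads, seq_len = 128, 8, 97
mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)
mha.eval()                                      # turn off dropout so the weights are deterministic

emb = torch.randn(seq_len, 1, embed_dim)        # [97, 1, 128], standing in for the CNN output

with torch.no_grad():
    out, attn = mha(emb, emb, emb,
                    need_weights=True,
                    average_attn_weights=False)  # keep one matrix per head

print(attn.shape)   # torch.Size([1, 8, 97, 97]) = [batch, head, query timepoint, key timepoint]
```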

Normally, attention matrices show which tokens in the sequence (in your case, time steps) are most relevant to solving the task. What usually gets highlighted is the main diagonal, with occasional spots elsewhere, indicating that a reordering of tokens may make more sense during decoding (e.g. translation).

But in your attention matrices there appear to be primarily vertical lines. That might mean the brighter columns (time steps) are what the model thinks it should focus on to classify the sequence, assuming it has been sufficiently trained.
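If you want to turn that into a heatmap over the original [1, 97, 1, 128, 128] tensor, one rough but common reduction is to average the attention over the heads and over the query dimension, which leaves a single score per time step. A sketch, assuming your weights are the [1, 8, 97, 97] tensor you described (the random tensor below just stands in for them):

```python
import torch
import matplotlib.pyplot as plt

attn = torch.rand(1, 8, 97, 97)               # stand-in for your [batch, head, query_t, key_t] weights
attn = attn / attn.sum(dim=-1, keepdim=True)  # make each row a distribution, like real softmax output

# Average over heads (dim 1) and over query time steps (dim 2):
# what remains is how much attention each key time step received overall.
importance = attn.mean(dim=1).mean(dim=1).squeeze(0)   # [97]
importance = importance / importance.max()              # normalise to [0, 1] for display

fig, (ax0, ax1) = plt.subplots(2, 1, sharex=True, figsize=(8, 4))
ax0.imshow(attn[0].mean(dim=0).numpy(), aspect="auto", cmap="viridis")  # head-averaged 97x97 matrix
ax0.set_ylabel("query time step")
ax1.plot(importance.numpy())                             # 1-D importance curve over the 97 timepoints
ax1.set_xlabel("key time step")
ax1.set_ylabel("relative attention")
plt.show()
```

Once you have that 97-value curve, you could use importance[t] as, say, the alpha of a colour overlay when you render frame t of the original sequence.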

Hey @J_Johnson,
thanks for your reply. In the NLP example you mentioned, the attention matrix is a (scaled) dot product of Q and K and, as a measure of similarity, shows how similar each token's representation is to its target. However, I still struggle to understand what the attention matrix shows in my case. There are two matrices, Q and K, obtained by linearly mapping the input tensor to Q, K and V, with a separate set of projection weights for each of the n_head heads. So Q and K in each head are just two linear projections of the same input tensor, and the attention matrix is computed as the dot product of Q and K. As I understand it, the dot product is therefore computed between two linear projections of the input tensor. Looking at the attention matrices above, Head 1 of Layer 1 has high weight values at roughly 10 of the 97 positions; does that mean there is high similarity between the corresponding vectors in K and Q? Is my intuition right? (I've put a small sketch of the computation I have in mind below.)
Why is the pattern mostly vertical?
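Here is the sketch of the computation I have in mind for a single head (standard scaled dot-product attention with random weights, not the exact internals of nn.MultiheadAttention):

```python
import torch
import torch.nn.functional as F

seq_len, embed_dim, num_heads = 97, 128, 8
head_dim = embed_dim // num_heads            # 16 dims per head

x = torch.randn(seq_len, embed_dim)          # the 97 token embeddings of one sequence

# Per-head projection matrices (learned in the real model; random here, shapes only).
W_q = torch.randn(embed_dim, head_dim)
W_k = torch.randn(embed_dim, head_dim)

Q = x @ W_q                                  # [97, 16] queries
K = x @ W_k                                  # [97, 16] keys

scores = Q @ K.T / head_dim ** 0.5           # [97, 97]: similarity of each query to each key timepoint
attn = F.softmax(scores, dim=-1)             # each row is a distribution over the 97 key timepoints

print(attn.shape)                            # torch.Size([97, 97]) -- one such matrix per head
print(attn[10].sum())                        # ~1.0: how query timepoint 10 spreads its attention
```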

If you’re using the PyTorch Transformer, then it is self-attention. And so it is simply determining which time steps are most relevant.

The vertical lines are a bit odd, and I don't know what to tell you in that case. How is the performance of the model after training?

If the model is fitting the data well, it's possible to read the axes like this: both axes are time steps (the rows are queries and the columns are keys), so a vertical line means that nearly every query time step is putting its attention on the same few key time steps.
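One way to sanity-check that reading: each row of the attention matrix is a softmax over the key time steps, so a vertical stripe at some column means nearly every query row puts its mass on that time step. A toy illustration (the "salient" column indices below are made up purely to reproduce the pattern):

```python
import torch
import torch.nn.functional as F

seq_len = 97
# Toy pre-softmax scores where key time steps 20 and 55 are "salient" for every query.
scores = torch.zeros(seq_len, seq_len)
scores[:, 20] = 5.0
scores[:, 55] = 5.0
attn = F.softmax(scores, dim=-1)             # rows = queries, columns = keys; each row sums to 1

print(attn.sum(dim=-1)[:3])                  # tensor([1., 1., 1.])
print(attn.mean(dim=0).topk(2).indices)      # the two brightest columns: time steps 20 and 55
```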

The performance of the multi-class classification, in terms of accuracy, is more than good: training, test, and hold-out validation accuracies are all > 0.94 ± 0.148.
Thanks for your thoughts