nn.MultiheadAttention: analyzing the attention weights


I'm using the nn.MultiheadAttention layer (based on the paper "Attention Is All You Need") to create an attended graph node embedding. My goal is to create a new embedding that contains the best elements of multiple embeddings. I'm posting this because I would like to know whether I'm using the layer correctly (although the results are good), and because I don't understand how to interpret the attention weights of the heads, which I would like to analyze.

I interpreted the heads somewhat like kernels, and I use 3 heads with the idea of assigning each embedding its own head. Does this choice make sense? I experimented with the number of heads, and both increasing and decreasing it led to worse results.
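For reference, here is a minimal sketch of how nn.MultiheadAttention splits the work between heads (the sizes and random inputs are placeholders, not the setup from this post). embed_dim must be divisible by num_heads, each head works on a projected slice of size embed_dim // num_heads, and every head attends over all positions along the sequence dimension. With the stacked input described in this post, that suggests each head sees all 3 embeddings rather than being tied to one of them.

```python
import torch
import torch.nn as nn

emb_dim, num_heads, num_nodes = 12, 3, 5  # illustrative sizes only

att = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=num_heads)

# Each head works on a slice of the projected embedding dimension ...
print(att.head_dim)  # 4 == emb_dim // num_heads

# ... but attends over all 3 stacked embeddings (the sequence dimension).
x = torch.randn(3, num_nodes, emb_dim)  # (seq_len=3, batch=num_nodes, emb_dim)
_, w = att(x, x, x, average_attn_weights=False)
print(w.shape)  # torch.Size([5, 3, 3, 3]) -> (nodes, heads, query pos, key pos)
```

Every head has its own full 3 × 3 weight matrix per node, so no head is restricted to a single input embedding.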

What I'm passing to the multi-head attention:
- I constructed 3 node embeddings, each of size (number of nodes, embedding dim).
- I stack these into a 3D tensor, so embedding.size() = (3, number of nodes, embedding dim).
- The stacked 3D tensor does not require gradients (it is frozen).
- For Q, K, and V in the multi-head attention I pass the same 3D 'embedding' tensor.
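The steps above can be sketched as follows, assuming random tensors stand in for the three precomputed embeddings and the sizes are made up. With the default batch_first=False, nn.MultiheadAttention reads a (3, num_nodes, emb_dim) tensor as (sequence length 3, batch of num_nodes, emb_dim), so each node is one batch element whose "sequence" is its 3 embeddings.

```python
import torch
import torch.nn as nn

num_nodes, emb_dim = 100, 48  # illustrative sizes

# Hypothetical stand-ins for the three precomputed node embeddings, each (num_nodes, emb_dim)
emb_a, emb_b, emb_c = (torch.randn(num_nodes, emb_dim) for _ in range(3))

# Stack along a new leading dim -> (3, num_nodes, emb_dim)
# With batch_first=False this is interpreted as (seq_len=3, batch=num_nodes, embed_dim)
embedding = torch.stack([emb_a, emb_b, emb_c], dim=0)
assert not embedding.requires_grad  # frozen input

att = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=3)
out, weights = att(embedding, embedding, embedding)  # self-attention: Q = K = V
print(out.shape)      # torch.Size([3, 100, 48])
print(weights.shape)  # torch.Size([100, 3, 3]), averaged over heads by default
```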

Output of the multi-head attention:
The output embedding has size (3, number of nodes, embedding dim).
After some tests, I'm using the tensor at index 0 as input for my RGCN.
This works and, as I said, the results are good. But does this make any sense? If you pass a 3D tensor to the multi-head attention layer, is the tensor at index 0 the attended tensor? Or does the layer learn that the tensor at index 0 should contain the attended values?
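One way to check what index 0 contains is sketched below, assuming the default batch_first=False (as in the code in this post) and placeholder sizes and inputs. It feeds only embedding 0 as the query while keeping all 3 embeddings as keys and values, then compares the result with position 0 of the full self-attention output. In this sketch the two match, which suggests index 0 is embedding 0's query attending over all three embeddings; positions 1 and 2 are built the same way from their own queries.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_nodes, emb_dim = 10, 12  # illustrative sizes
embedding = torch.randn(3, num_nodes, emb_dim)

# dropout defaults to 0 here; with dropout=0.2 (as in the post) call att.eval() first
att = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=3)

full_out, full_w = att(embedding, embedding, embedding)  # (3, N, E), (N, 3, 3)

# Only embedding 0 as the query, still attending over all 3 embeddings
q0_out, q0_w = att(embedding[0:1], embedding, embedding)  # (1, N, E), (N, 1, 3)

# Position 0 of the full output equals "embedding 0 attending over embeddings 0, 1, 2"
print(torch.allclose(full_out[0], q0_out[0], atol=1e-5))   # True
print(torch.allclose(full_w[:, 0], q0_w[:, 0], atol=1e-5))  # True
```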

How to analyze the attention weights?
- I would like to know how the tensor at index 0 is constructed.
- Does the newly constructed embedding tensor mainly contain values from one of the stacked embedding tensors?
- The attention weight matrices of the heads become pretty large, so averaging over the heads seems like a nice option, but will you lose the ability to determine whether some of the input tensors are attended to more than others?
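A minimal sketch of one way to inspect the weights, assuming placeholder sizes and a random stand-in for the frozen stacked embeddings. It requests per-head weights with average_attn_weights=False, keeps the row for query position 0 (the output at index 0), and summarizes it per head, head-averaged, and per node. att.eval() matters here: in training mode the returned weights include dropout. Averaging over heads still keeps the key dimension, so the head-averaged weights can still show which input embedding is attended to more; what averaging loses is only the difference between heads.

```python
import torch
import torch.nn as nn

num_nodes, emb_dim, num_heads = 100, 48, 3  # illustrative sizes
embedding = torch.randn(3, num_nodes, emb_dim)  # stand-in for the frozen stacked embeddings
att = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=num_heads, dropout=0.2)
att.eval()  # disable dropout so the returned weights are not randomly zeroed/rescaled

with torch.no_grad():
    _, w = att(embedding, embedding, embedding, average_attn_weights=False)
# w: (num_nodes, num_heads, query_pos, key_pos) = (100, 3, 3, 3)

# Row for query position 0: how much output index 0 draws from embeddings 0, 1, 2
w0 = w[:, :, 0, :]              # (num_nodes, num_heads, 3); each row sums to 1

per_head = w0.mean(dim=0)       # (num_heads, 3): mean weight per head and input embedding
head_avg = w0.mean(dim=1)       # (num_nodes, 3): same as average_attn_weights=True, row 0
overall = head_avg.mean(dim=0)  # (3,): average share of each input embedding

# For each node, which input embedding receives the largest head-averaged weight
dominant = head_avg.argmax(dim=1)               # (num_nodes,)
counts = torch.bincount(dominant, minlength=3)  # nodes "won" by each input embedding
print(per_head)
print(overall, counts)
```

Keep in mind that the weights mix the value projections of the embeddings, and the result then passes through the output projection, so they describe relative contributions rather than literal copies of the input values.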

All replies appreciated!

These are the layers I'm using:

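```python
from typing import Callable

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.nn import RGCNConv


class Emb_ATT_Layers(nn.Module):
    def __init__(self, num_relations: int, hidden_l: int, num_labels: int, _, emb_dim: int, num_embs: int) -> None:
        super(Emb_ATT_Layers, self).__init__()
        # Frozen stacked embeddings of shape (num_embs, num_nodes, emb_dim); must be set before forward()
        self.embedding = None
        # num_heads = num_embs; emb_dim must be divisible by num_embs
        self.att = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=num_embs, dropout=0.2)
        self.rgcn1 = RGCNConv(in_channels=emb_dim, out_channels=hidden_l, num_relations=num_relations, num_bases=None)
        self.rgcn2 = RGCNConv(hidden_l, num_labels, num_relations, num_bases=None)
        nn.init.kaiming_uniform_(self.rgcn1.weight, mode='fan_in')
        nn.init.kaiming_uniform_(self.rgcn2.weight, mode='fan_in')

    def forward(self, training_data: Data, activation: Callable) -> Tensor:
        # Self-attention over the stacked embeddings (Q = K = V); with the default batch_first=False
        # the input is read as (seq_len=num_embs, batch=num_nodes, emb_dim)
        attn_output, att_weights = self.att(self.embedding, self.embedding, self.embedding, average_attn_weights=True)
        # att_weights: (num_nodes, num_embs, num_embs), averaged over the heads
        # print(att_weights)
        x = attn_output[0]  # output at sequence position 0: (num_nodes, emb_dim)
        x = self.rgcn1(x, training_data.edge_index, training_data.edge_type)
        x = F.relu(x)
        x = self.rgcn2(x, training_data.edge_index, training_data.edge_type)
        x = activation(x)
        return x
```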