Visualize Node Importance with Attention while doing Graph Embedding

I am quite new to the concept of attention. I am working with graph data and running graph convolutions on it to learn node-level embeddings first, followed by an attention layer that aggregates the nodes into a graph-level embedding. Here is the setup:

graph -> Conv1 (filter size 128) -> Conv2 (filter size 64) -> Conv3 (filter size 32) -> Attention -> some other layers
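
For reference, this is roughly how I build that stack. I am assuming torch_geometric's GCNConv here (which I believe the repo linked below also uses); the class and layer names are just illustrative:

    import torch
    from torch_geometric.nn import GCNConv

    class GCNStack(torch.nn.Module):
        def __init__(self, num_node_features):
            super().__init__()
            self.conv1 = GCNConv(num_node_features, 128)
            self.conv2 = GCNConv(128, 64)
            self.conv3 = GCNConv(64, 32)

        def forward(self, x, edge_index):
            # x: [number_of_nodes, num_node_features]
            x = torch.relu(self.conv1(x, edge_index))
            x = torch.relu(self.conv2(x, edge_index))
            return self.conv3(x, edge_index)  # [number_of_nodes, 32]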

After the three convolution passes I get a matrix of size number_of_nodes_in_the_graph x 32 (the embedding length). After the attention layer I get a flat vector representation of the graph of length 32. Here is the forward function of the attention module:

    def forward(self, embedding):
        """
        Making a forward propagation pass to create a graph-level representation.
        :param embedding: Result of the GCN, shape [number_of_nodes, 32].
        :return representation: A graph-level representation vector.
        :return transformed_global: The transformed global context used to score the nodes.
        """
        # mean over all nodes of the transformed embeddings -> [32]
        global_context = torch.mean(torch.matmul(embedding, self.weight_matrix), dim=0)

        transformed_global = torch.tanh(global_context)  # [32]

        # one attention score per node -> [number_of_nodes, 1]
        sigmoid_scores = torch.sigmoid(torch.mm(embedding, transformed_global.view(-1, 1)))

        # weighted sum of the node embeddings -> [32, 1]
        representation = torch.mm(torch.t(embedding), sigmoid_scores)
        return representation, transformed_global

Now I would like to see which graph nodes were important for the final graph embedding. I am using PyTorch for this, but I cannot seem to figure out how to map the attention output back to the input nodes. I would greatly appreciate any help! I am using this repo: https://github.com/benedekrozemberczki/SimGNN


Hi @Sajjad_Hossain, are you considering applying attribution algorithms to the outputs with respect to the inputs in order to compute importance scores? What kind of importance are you looking for?

We have a library for model interpretability that lets you attribute an output target to the inputs or to the neurons of the network. Is that something that could be useful for you?
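
For example, with an attribution algorithm such as Integrated Gradients you could get one importance score per node for a chosen graph-level output. A rough sketch (this uses Captum's implementation as an example; the tiny stand-in model and the names below are only there to keep the snippet self-contained, you would wrap your GCN + attention stack instead and pass edge_index via additional_forward_args):

    import torch
    from captum.attr import IntegratedGradients

    # Stand-in for "node features of one graph -> one graph-level score".
    num_nodes, num_features = 10, 16
    readout = torch.nn.Linear(num_features, 1)

    def graph_score(node_features):               # [1, num_nodes, num_features]
        return readout(node_features).sum(dim=1)  # [1, 1] graph-level output

    node_features = torch.randn(1, num_nodes, num_features)
    ig = IntegratedGradients(graph_score)
    attributions = ig.attribute(node_features, target=0)  # same shape as the input
    node_importance = attributions[0].abs().sum(dim=1)    # one score per node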

Hello @Narine_Kokhlikyan, thank you very much for your reply. I am not particularly looking for attribution scores, but that is really interesting. I am just trying to understand how the attention mechanism plays its role here: which specific variable (attention weight, activation, gradient, or something else) will provide intuition about which node embedding vectors contributed to the final graph embedding?
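
For example, my current guess is that sigmoid_scores is the quantity to look at, since it is one value per node and it directly weights each node's embedding in the final weighted sum. Something along these lines (the tensors here are random stand-ins for the real Conv3 output and weight matrix):

    import torch

    number_of_nodes = 10
    node_embeddings = torch.randn(number_of_nodes, 32)  # stand-in for the Conv3 output
    weight_matrix = torch.randn(32, 32)                 # stand-in for self.weight_matrix

    # same computation as in forward(), kept step by step so each node's score is visible
    global_context = torch.mean(torch.matmul(node_embeddings, weight_matrix), dim=0)
    transformed_global = torch.tanh(global_context)
    sigmoid_scores = torch.sigmoid(torch.mm(node_embeddings, transformed_global.view(-1, 1)))

    node_importance = sigmoid_scores.squeeze(1)          # [number_of_nodes], one score per node
    print(torch.topk(node_importance, k=3).indices)      # e.g. the three highest-scoring nodes

Is that the right quantity to interpret, or should I be looking at gradients instead?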