I am quite new to the concept of attention. I am working with graph data and running graph convolutions on it to first learn node-level embeddings, then an attention layer that aggregates the nodes into a graph-level embedding. Here is the setup:
graph -> Conv1 (filter size 128) -> Conv2 (filter size 64) -> Conv3 (filter size 32) -> Attention -> some other layers
After the three convolution passes I get a matrix of size number_of_nodes_in_the_graph x 32 (the embedding length). After the attention layer I get a flat vector representation of the graph of length 32. Here is the forward function of the attention module:
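To make the shapes concrete, here is a minimal sketch of the node-level part of this pipeline. It assumes a simple dense graph convolution (symmetrically normalized adjacency with self-loops, times node features, times a weight matrix, followed by ReLU); `DenseGraphConv`, `normalize_adjacency` and `NodeEncoder` are hypothetical names, and this is not the convolution layer the linked repo uses. It only illustrates how one graph goes from node features to a number_of_nodes x 32 matrix, keeping one row per node in the original node order.

```python
import torch


class DenseGraphConv(torch.nn.Module):
    """Hypothetical dense graph convolution: relu(A_hat @ H @ W)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = torch.nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, a_hat, h):
        # a_hat: (N, N), h: (N, in_dim) -> (N, out_dim); row i stays node i
        return torch.relu(a_hat @ self.linear(h))


def normalize_adjacency(adj):
    """Return D^{-1/2} (A + I) D^{-1/2} for a dense symmetric (N, N) adjacency."""
    a = adj + torch.eye(adj.size(0))          # add self-loops, so every degree >= 1
    deg_inv_sqrt = a.sum(dim=1).pow(-0.5)
    return deg_inv_sqrt.unsqueeze(1) * a * deg_inv_sqrt.unsqueeze(0)


class NodeEncoder(torch.nn.Module):
    """Three convolutions with 128, 64 and 32 filters, as in the setup above."""

    def __init__(self, in_dim):
        super().__init__()
        self.convs = torch.nn.ModuleList([
            DenseGraphConv(in_dim, 128),
            DenseGraphConv(128, 64),
            DenseGraphConv(64, 32),
        ])

    def forward(self, adj, features):
        a_hat = normalize_adjacency(adj)
        h = features
        for conv in self.convs:
            h = conv(a_hat, h)
        return h  # shape: (num_nodes, 32)


torch.manual_seed(0)
num_nodes = 5
# Random undirected graph without self-loops (illustrative data only).
upper = torch.triu((torch.rand(num_nodes, num_nodes) > 0.5).float(), diagonal=1)
adj = upper + upper.t()
features = torch.randn(num_nodes, 16)

encoder = NodeEncoder(in_dim=16)
emb = encoder(adj, features)
print(emb.shape)  # torch.Size([5, 32])
```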
def forward(self, embedding):
    """
    Making a forward propagation pass to create a graph level representation.
    :param embedding: Result of the GCN, shape (number_of_nodes, 32).
    :return representation: A graph level representation vector, shape (32, 1).
    :return transformed_global: The global context vector, shape (32,).
    """
    # Transform every node embedding and average them into one global context vector.
    global_context = torch.mean(torch.matmul(embedding, self.weight_matrix), dim=0)  # (32,)
    transformed_global = torch.tanh(global_context)  # (32,)
    # One score per node: sigmoid of the node's dot product with the global context.
    sigmoid_scores = torch.sigmoid(torch.mm(embedding, transformed_global.view(-1, 1)))  # (number_of_nodes, 1)
    # Graph representation: score-weighted sum of the node embeddings.
    representation = torch.mm(torch.t(embedding), sigmoid_scores)  # (32, 1)
    return representation, transformed_global
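Written out, the forward function above computes the following, where U is the number_of_nodes x 32 output of Conv3 (row u_n is node n's embedding, N is the number of nodes), W is the 32 x 32 `weight_matrix` and σ is the sigmoid. The graph vector h is a weighted sum of the node embeddings, and a_n (`sigmoid_scores[n]`) is the weight given to node n.

$$
\mathbf{c} = \tanh\!\Big(\frac{1}{N}\sum_{n=1}^{N} W^\top \mathbf{u}_n\Big), \qquad
a_n = \sigma\big(\mathbf{u}_n^\top \mathbf{c}\big), \qquad
\mathbf{h} = \sum_{n=1}^{N} a_n\, \mathbf{u}_n = U^\top \mathbf{a}
$$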
Now I would like to see which graph nodes were important for the final graph embedding. I am using PyTorch for this. I cannot figure out how to map the attention output back to the input nodes. I would greatly appreciate any help! I am using this repo: https://github.com/benedekrozemberczki/SimGNN
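One possible starting point is to read off `sigmoid_scores`, since each entry multiplies exactly one row of `embedding` in the final sum. Below is a minimal sketch that repeats the same computation as the forward function in the question but returns the per-node scores instead of `transformed_global`; `AttentionWithScores` and `rank_nodes` are hypothetical names, the random `embedding` stands in for the Conv3 output of one graph, and the weight initialization is an assumption, not taken from the repo.

```python
import torch


class AttentionWithScores(torch.nn.Module):
    """Hypothetical variant of the attention module that also returns per-node scores."""

    def __init__(self, dim=32):
        super().__init__()
        self.weight_matrix = torch.nn.Parameter(torch.empty(dim, dim))
        torch.nn.init.xavier_uniform_(self.weight_matrix)  # assumed initialization

    def forward(self, embedding):
        # embedding: (num_nodes, dim); same computation as the original forward()
        global_context = torch.mean(torch.matmul(embedding, self.weight_matrix), dim=0)
        transformed_global = torch.tanh(global_context)                 # (dim,)
        sigmoid_scores = torch.sigmoid(
            torch.mm(embedding, transformed_global.view(-1, 1)))        # (num_nodes, 1)
        representation = torch.mm(torch.t(embedding), sigmoid_scores)  # (dim, 1)
        return representation, sigmoid_scores


def rank_nodes(scores):
    """Return node indices sorted from highest to lowest attention score."""
    return torch.argsort(scores.view(-1), descending=True)


torch.manual_seed(0)
embedding = torch.randn(6, 32)  # stand-in for the Conv3 output of a 6-node graph
attention = AttentionWithScores(dim=32)

with torch.no_grad():
    representation, scores = attention(embedding)

order = rank_nodes(scores)
for node in order.tolist():
    print(f"node {node}: score {scores[node, 0].item():.3f}")
```

Graph convolutions keep one output row per input node, so row i of `embedding`, and therefore `scores[i]`, should correspond to node i in the node feature matrix fed into Conv1, assuming the features and edge indices use the same node numbering. Because each score is an independent sigmoid rather than a softmax, the scores do not sum to 1; they are most meaningful when compared against each other within the same graph.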