How to extract/print the node weights

Hi all,

I'm working on a graph convolutional network (GCN) for a regression task. Each node has 2 features.

I applied self-attention graph pooling (SAGPooling) to find out which nodes contribute to the final prediction: every node gets a weight (score) that affects the final result.

So my question is: how can I print or extract these weights to rank the importance of the nodes, and at what point should I do it?


Here is my model:

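```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGPooling


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv1 = GCNConv(2, 5)
        self.conv2 = GCNConv(5, 8)
        self.pool1 = SAGPooling(8)
        self.fc1 = nn.Linear(538136, 3)
        self.fc2 = nn.Linear(3, 1)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # SAGPooling returns (x, edge_index, edge_attr, batch, perm, score)
        x, edge_index, _, _, _, _ = self.pool1(x, edge_index, None, batch)
        # Assumed flatten: fc1 expects 538136 inputs, i.e. (nodes kept per graph) * 8,
        # while x here has shape [num_kept_nodes, 8]
        x = x.view(-1, 538136)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = torch.sigmoid(x)
        return x
```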