Hi all,

I’m working on a graph convolutional network (GCN) for a regression task; each node has 2 features.

I applied self-attention pooling (SAGPooling) to find out which nodes contribute to the final prediction. Each node gets a weight that affects the final result.

So, my question is: how can I print or extract these weights so that I can rank the nodes by importance, and at what point should I do it?

Thanks.

Here is my model:

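    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch_geometric.nn import GCNConv, SAGPooling

    class MyNet(nn.Module):
        def __init__(self):
            super(MyNet, self).__init__()
            self.conv1 = GCNConv(2, 5)   # 2 input features per node
            self.conv2 = GCNConv(5, 8)
            self.pool1 = SAGPooling(8)
            self.fc1 = nn.Linear(538136, 3)
            self.fc2 = nn.Linear(3, 1)

        def forward(self, data):
            x, edge_index, batch = data.x, data.edge_index, data.batch
            x = F.relu(self.conv1(x, edge_index))
            x = F.relu(self.conv2(x, edge_index))
            # SAGPooling returns six values; only the first two are kept here
            x, edge_index, _, _, _, _ = self.pool1(x, edge_index, None, batch)
            b = data.y.shape[0]  # batch size, taken from the number of targets
            x = x.view(b, -1)    # flatten the pooled node features per sample
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            x = torch.sigmoid(x)
            return x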
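
To make the goal concrete, here is a minimal sketch of the ranking step. It assumes the last two values returned by `self.pool1(...)` are the indices of the nodes kept by the pooling layer (`perm`) and their attention scores (`score`), which is my reading of the SAGPooling return order. `rank_nodes` is a hypothetical helper, not part of any library, and the numbers are made up.

```python
def rank_nodes(perm, score):
    """Pair each kept node index with its score, sorted by score (highest first).

    perm and score are plain Python sequences of equal length; for tensors,
    convert with .tolist() first (assumption: they are 1-D and aligned).
    """
    if len(perm) != len(score):
        raise ValueError("perm and score must have the same length")
    pairs = list(zip(perm, score))
    return sorted(pairs, key=lambda p: p[1], reverse=True)


# Made-up values for illustration only, not real model output
perm = [4, 0, 7]
score = [0.12, 0.91, 0.55]

for node, s in rank_nodes(perm, score):
    print(f"node {node}: score {s:.2f}")
```

In the model above this would mean unpacking `x, edge_index, _, batch, perm, score = self.pool1(x, edge_index, None, batch)` inside `forward` and passing `perm.tolist()` and `score.tolist()` to the helper. Under the same assumption, `perm` indexes into the batched node tensor rather than into each graph separately, and only the nodes that survive pooling get a score.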