Hi all,
I’m working on a graph convolutional network (GCN) for a regression task; each node has 2 features.
I applied self-attention pooling (SAGPooling) to find out which nodes contribute to the final prediction: every node gets a weight that affects the final result.
So my question is: how can I print or extract these weights so that I can rank the nodes by importance, and at what point in the model should I do it?
Thanks.
Here is my model:
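import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGPooling

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        # Two GCN layers: 2 input features per node -> 5 -> 8
        self.conv1 = GCNConv(2, 5)
        self.conv2 = GCNConv(5, 8)
        # Self-attention graph pooling over the 8-dimensional node embeddings
        self.pool1 = SAGPooling(8)
        self.fc1 = nn.Linear(538136, 3)
        self.fc2 = nn.Linear(3, 1)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # The pooling layer returns six values; only x and edge_index are kept here
        x, edge_index, _, _, _, _ = self.pool1(x, edge_index, None, batch)
        # Flatten the pooled node features into one row per graph in the batch
        b = data.y.shape[0]
        x = x.view(b, -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = torch.sigmoid(x)
        return x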
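
To make the goal concrete, here is a rough sketch of the ranking I'm after. It assumes that the last two values returned by the pooling call are the indices of the kept nodes (perm) and their attention scores (score), which is how I read the SAGPooling output but have not confirmed. The helper name rank_kept_nodes and the sample values are made up for illustration; in the real model the lists would come from something like perm.tolist() and score.tolist().

```python
from typing import List, Tuple


def rank_kept_nodes(perm: List[int], score: List[float]) -> List[Tuple[int, float]]:
    """Pair each kept node index with its score and sort by score, highest first.

    Hypothetical helper: `perm` and `score` are assumed to be plain Python
    lists of equal length, one entry per node kept by the pooling layer.
    """
    if len(perm) != len(score):
        raise ValueError("perm and score must have the same length")
    pairs = list(zip(perm, score))
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)


# Made-up values standing in for perm.tolist() and score.tolist()
perm = [4, 0, 2]
score = [0.91, 0.35, 0.62]

ranking = rank_kept_nodes(perm, score)
for node_index, node_score in ranking:
    print(f"node {node_index}: score {node_score:.2f}")
```

As far as I understand, with batched input the indices in perm refer to positions in the concatenated node tensor of the whole batch, so they would still need to be mapped back to each graph's own node numbering.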