I’m working with PyTorch Geometric (PyG) and heterogeneous graphs.
How can I aggregate the node embeddings across node types to obtain a single embedding for the whole graph?
class GNN(torch.nn.Module):
    """Two-layer GraphSAGE encoder intended to be converted with ``to_hetero``.

    ``to_hetero`` replicates the node-wise operations of ``forward`` once per
    node type, and that includes the final pooling. Calling the converted
    model with ``pool=True`` therefore returns a dict holding one pooled
    ``[1, out_channels]`` tensor per node type, not a single graph embedding.
    Merge those per-type vectors afterwards (for example with
    ``combine_node_type_embeddings`` below) to get one graph-level vector.

    Args:
        hidden_channels: Output width of the first SAGEConv layer.
        out_channels: Output width of the second SAGEConv layer (the
            embedding size).
        pool: If True (the default, original behaviour), ``forward`` returns
            ``global_mean_pool(x, batch=None)``. That is a single ``[1, C]``
            row that averages over *all* nodes it receives, so graphs in a
            mini-batch are merged. Set to False to get the per-node
            embeddings back and pool them yourself, for example with a real
            ``batch`` vector.
    """

    def __init__(self, hidden_channels, out_channels, pool=True):
        super().__init__()
        # (-1, -1): input sizes are inferred lazily on the first forward pass.
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)
        # Plain Python bool, so it is read as a constant when the module is
        # traced by to_hetero (no data-dependent control flow).
        self.pool = pool

    def forward(self, x, edge_index):
        """Encode nodes; optionally mean-pool them into one row."""
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        if not self.pool:
            return x
        # batch=None treats every node given here as one graph -> shape [1, C].
        return global_mean_pool(x, batch=None)


def combine_node_type_embeddings(emb_dict, node_types=None, reduce="mean"):
    """Merge per-node-type embeddings into one graph embedding.

    Args:
        emb_dict: Mapping from node type to a tensor. Every tensor must have
            the same shape, such as the ``{node_type: [1, C]}`` dict returned
            by a ``to_hetero``-converted ``GNN``.
        node_types: Optional explicit key order (e.g. ``data.node_types``).
            Defaults to the dict's own iteration order. The order only
            affects the result for ``reduce="cat"``.
        reduce: ``"mean"``, ``"sum"`` or ``"max"`` reduce across node types
            (result shape ``[..., C]``). ``"mean"`` weights every node type
            equally, however many nodes it has. ``"cat"`` concatenates along
            the last dimension instead (result ``[..., num_types * C]``), so
            type-specific information is kept side by side.

    Returns:
        A single tensor representing the whole graph.

    Raises:
        ValueError: If there are no node types to combine or ``reduce`` is
            not one of the supported options.
    """
    if reduce not in ("mean", "sum", "max", "cat"):
        raise ValueError(
            f"unknown reduce {reduce!r}; expected 'mean', 'sum', 'max' or 'cat'"
        )
    keys = list(emb_dict) if node_types is None else list(node_types)
    if not keys:
        raise ValueError("emb_dict must contain at least one node type")
    tensors = [emb_dict[key] for key in keys]
    if reduce == "cat":
        return torch.cat(tensors, dim=-1)
    # [num_types, ...] -> reduce over the new leading node-type axis.
    stacked = torch.stack(tensors, dim=0)
    if reduce == "mean":
        return stacked.mean(dim=0)
    if reduce == "sum":
        return stacked.sum(dim=0)
    return stacked.max(dim=0).values
# Build the homogeneous encoder, then let to_hetero replicate it for every
# node/edge type in the graph's metadata.
model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')

with torch.no_grad():  # Initialize lazy modules; this output is discarded.
    model(data.x_dict, data.edge_index_dict)

# The converted model returns a dict {node_type: [1, C]}: one pooled
# embedding per node type, not one for the whole graph.
out = model(data.x_dict, data.edge_index_dict)

# Stack the per-type vectors -> [num_node_types, 1, C] and average across
# node types -> a single [1, C] graph embedding. The mean weights every node
# type equally, however many nodes it has. Use .sum(dim=0) / .max(dim=0)
# for other reductions, or torch.cat(list(out.values()), dim=-1) to keep the
# per-type embeddings side by side ([1, num_node_types * C]).
graph_emb = torch.stack(list(out.values()), dim=0).mean(dim=0)
Result (a dict with one pooled embedding per node type, rather than a single graph embedding):
{'node_type_1': tensor([[-0.0687, -0.2087]], grad_fn=<MeanBackward1>),
'node_type_2': tensor([[ 0.0210, -0.0251]], grad_fn=<MeanBackward1>),
'node_type_3': tensor([[ 0.1923, -0.0379]], grad_fn=<MeanBackward1>),
'node_type_4': tensor([[-0.0792, 0.1417]], grad_fn=<MeanBackward1>)}