I’m trying to make a GNN where, after a few convolutions, it computes the dot product between one of the node embeddings and all the rest. This is to do link prediction on nodes that might not exist in the current graph. However, as far as I understand, each PyTorch graph only contains the node representations for its own nodes, not for nodes that exist elsewhere in the dataset.
How can you access node embeddings for nodes that don’t exist in the current graph?
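```python
# example
def forward(self, x, edge_index):
    x = self.conv1(x, edge_index)
    x = self.conv2(x, edge_index)
    # dot product of node 0's embedding with all other node embeddings x[1:]
    # (x[1:] @ x[0] gives one score per other node; x @ x[1:] would be a shape mismatch)
    # x[1:] only has the nodes contained in this particular input graph,
    # but I want all the nodes in the entire dataset
    return x[1:] @ x[0]
```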
Edit: Some clarification.
The initial node values are integers, stored in whatever `data.x` you pass into the `forward()` function. I’m assuming there is some sort of global `torch.nn.Embedding` layer that converts these node integers into the initial embeddings that `self.conv1` acts on. I am looking for these initial embeddings for all possible nodes.
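To make the assumption concrete, here is a minimal plain-PyTorch sketch (no PyG) of the setup I have in mind. It assumes that every node in the dataset has a global integer ID in `[0, num_nodes_total)` and that `data.x` holds those IDs. `GlobalEmbeddingScorer`, `propagate`, `num_nodes_total` and `query` are hypothetical names, and the sum-aggregation plus `nn.Linear` is only a stand-in for `conv1`/`conv2`. The point is that a single `nn.Embedding` table shared by all graphs stores every node's initial embedding in `self.embedding.weight`, including nodes that are absent from the current graph.

```python
import torch
import torch.nn as nn


class GlobalEmbeddingScorer(nn.Module):
    """Hypothetical model: one embedding table shared by every graph in the dataset."""

    def __init__(self, num_nodes_total: int, dim: int):
        super().__init__()
        # Row i is the initial embedding of global node ID i,
        # whether or not that node appears in the current graph.
        self.embedding = nn.Embedding(num_nodes_total, dim)
        self.lin = nn.Linear(dim, dim)  # stand-in for conv1/conv2

    def propagate(self, h, edge_index):
        # Placeholder convolution: each node keeps its own features and
        # adds the features of its incoming neighbours, then a linear layer.
        src, dst = edge_index
        out = h.clone()
        out.index_add_(0, dst, h[src])
        return torch.relu(self.lin(out))

    def forward(self, x, edge_index, query=0):
        # x: global node IDs of this graph's nodes, shape [N] or [N, 1]
        h = self.embedding(x.view(-1))
        h = self.propagate(h, edge_index)
        # Initial embeddings of *all* nodes in the dataset: [num_nodes_total, dim]
        all_nodes = self.embedding.weight
        # One score per dataset node for the (local) query node
        return all_nodes @ h[query]


torch.manual_seed(0)
model = GlobalEmbeddingScorer(num_nodes_total=100, dim=16)
x = torch.tensor([[5], [17], [42]])                # global IDs of the 3 nodes in this graph
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # local indices into x
scores = model(x, edge_index)
print(scores.shape)  # torch.Size([100])
```

Note that this compares the convolved query embedding against the *initial* (unconvolved) embeddings of all nodes, which is one possible design choice; running the convolutions over the full graph instead would give convolved embeddings for every node.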