I’m trying to build a GNN that, after a few convolutions, computes the dot product between one node’s embedding and all the others. The goal is link prediction against nodes that might not exist in the current graph. However, as I understand it, PyTorch graphs only contain the node representations for their own nodes, not for other nodes that exist elsewhere in the dataset.
How can you access node embeddings for nodes that don’t exist in the current graph?
# example
def forward(self, x, edge_index):
    x = self.conv1(x, edge_index)
    x = self.conv2(x, edge_index)
    # dot product of vector x[0] with all other vectors x[1:]
    # (x[1:] @ x[0] gives one score per other node)
    return x[1:] @ x[0]
# x[1:] only has the nodes contained in this particular input graph,
# but I want all the nodes in the entire dataset
Edit: Some clarification.
The initial node values are ints, stored in whatever data.x you pass into the forward() function. I’m assuming there is some sort of global torch.nn.Embedding layer that converts the node integers into the initial embeddings that self.conv1 acts upon. I am looking for these initial embeddings for all possible nodes.
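To make concrete what I’m imagining, here is a minimal sketch. It is not working code for my actual model: the convolutions are omitted, and the names NUM_NODES_TOTAL, EMB_DIM, and score_against_all are hypothetical. It just shows a single global torch.nn.Embedding table over every node ID in the dataset, so one node can be scored against all nodes rather than only those in the current graph:

```python
import torch
import torch.nn as nn

# Hypothetical: a global embedding table over ALL nodes in the dataset,
# independent of any single input graph. Node "features" are integer IDs
# that index into this table.
NUM_NODES_TOTAL = 1000  # assumed total node count across the dataset
EMB_DIM = 64

node_emb = nn.Embedding(NUM_NODES_TOTAL, EMB_DIM)

def score_against_all(anchor_id: int) -> torch.Tensor:
    """Dot-product the anchor node's embedding with every node embedding.

    In the real model the per-graph convolutions would run in between;
    here we score the raw initial embeddings directly.
    """
    anchor = node_emb(torch.tensor(anchor_id))  # shape: (EMB_DIM,)
    all_vecs = node_emb.weight                  # shape: (NUM_NODES_TOTAL, EMB_DIM)
    return all_vecs @ anchor                    # shape: (NUM_NODES_TOTAL,)

scores = score_against_all(0)
```

In a per-graph forward() the table would be indexed with the current graph’s node IDs (node_emb(data.x)) to get the initial features for self.conv1, while the full node_emb.weight stays available for scoring against the whole dataset.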