PyTorch Geometric - accessing global node embeddings

I’m trying to make a GNN where, after a few convolutions, it computes the dot product between one node’s embedding and all the rest. This is to do link prediction on nodes that might not exist in the current graph. However, to my understanding a PyTorch Geometric `Data` object only contains the node representations for its own nodes, not for other nodes that exist elsewhere in the dataset.

How can you access node embeddings for nodes that don’t exist in the current graph?

# example
def forward(self, x, edge_index):
    x = self.conv1(x, edge_index)
    x = self.conv2(x, edge_index)

    # dot product of vector x[0] with every other row of x
    # (x[1:] @ x[0] gives one score per remaining node;
    #  x[0] @ x[1:] would be a shape mismatch)
    return x[1:] @ x[0]
    
# x[1:] only has the nodes contained in this particular input graph, but I want all the nodes in the entire dataset
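For reference, the shapes of that scoring step work out like this (plain torch, with random embeddings standing in for the conv output):

```python
import torch

x = torch.randn(5, 8)   # 5 node embeddings of dimension 8
scores = x[1:] @ x[0]   # dot product of node 0 with nodes 1..4
print(scores.shape)     # torch.Size([4])

# equivalent elementwise form, one dot product per row:
assert torch.allclose(scores, (x[1:] * x[0]).sum(dim=1))
```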

Edit: Some clarification.

The initial node values are ints, and are stored in whatever data.x you pass into the forward() function. I’m assuming there is some sort of global torch.nn.Embedding layer that converts the node integers into the initial embeddings that self.conv1 acts upon. I am looking for these initial embeddings for all possible nodes.

Never mind, I found the answer. GCN layers have a simple Linear layer that projects the input node representations into the right dimension. There does not seem to be an actual Embedding layer unless you add a custom one yourself. Here’s a snippet from GCNConv:

def __init__(self, in_channels: int, out_channels: int,
             improved: bool = False, cached: bool = False,
             add_self_loops: bool = True, normalize: bool = True,
             bias: bool = True, **kwargs):

    kwargs.setdefault('aggr', 'add')
    super().__init__(**kwargs)

    self.in_channels = in_channels
    self.out_channels = out_channels

    # ... irrelevant stuff

    self.lin = Linear(in_channels, out_channels, bias=False,
                      weight_initializer='glorot')
    # ... irrelevant stuff


def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    if self.normalize:
        # ... irrelevant stuff

    x = self.lin(x)  # <--- project input

    # ... irrelevant stuff
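So if you want a global embedding table, you add it yourself. Here is a minimal sketch of that idea: one `torch.nn.Embedding` row per node in the *entire* dataset, indexed by global integer id, so every node has an initial embedding even when it is absent from the current graph. The class name and `num_total_nodes` parameter are hypothetical, and the two `nn.Linear` layers are stand-ins for `GCNConv` layers so the sketch runs without torch_geometric installed:

```python
import torch
import torch.nn as nn

class GlobalEmbedModel(nn.Module):
    # Hypothetical sketch: one Embedding row per node in the whole dataset,
    # so embeddings exist even for nodes not in the current graph.
    def __init__(self, num_total_nodes: int, dim: int):
        super().__init__()
        self.embed = nn.Embedding(num_total_nodes, dim)
        # In a real model these would be GCNConv(dim, dim) layers taking
        # (x, edge_index); Linear stands in here so the sketch is runnable.
        self.conv1 = nn.Linear(dim, dim)
        self.conv2 = nn.Linear(dim, dim)

    def forward(self, node_ids: torch.Tensor) -> torch.Tensor:
        x = self.embed(node_ids)   # look up rows by *global* integer id
        x = torch.relu(self.conv1(x))
        x = self.conv2(x)
        return x[1:] @ x[0]        # score node 0 against all the others

model = GlobalEmbedModel(num_total_nodes=1000, dim=16)
ids = torch.tensor([0, 42, 999])   # global ids, not positions in one graph
scores = model(ids)
print(scores.shape)                # torch.Size([2])
```

Because the table covers every node id in the dataset, you can also score node 0 against nodes that never appear in the current subgraph by looking their rows up directly with `model.embed(other_ids)`.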