How to adjust Linear Layers to Shapes of Adjacency Matrices?

Hi,

My task is to predict a label matrix that has the same shape as the adjacency matrix, where each entry is predicted from a pair-wise dot product of node embeddings. I’m using the torch_geometric library for this purpose. The problem is that I don’t know how to set the appropriate output dimension of the final Linear layer for this prediction task:

class PMIPredictor(pl.LightningModule):
    def __init__(self, **kargs):
        super().__init__()
        ...

        self.encoder = GSequential("x, edge_index, edge_attr", [
            (GATConv(feat_size, emb_size, heads=heads, edge_dim=edge_features),
             "x, edge_index, edge_attr -> x"),
            ReLU(),
            (GATConv(emb_size * heads, emb_size, heads=heads, edge_dim=edge_features),
             "x, edge_index, edge_attr -> x"),
            ReLU(),
            (GATConv(emb_size * heads, latent_size, heads=1, edge_dim=edge_features),
             "x, edge_index, edge_attr -> x"),
            ReLU()
        ])

        self.predictor = Sequential(
            Linear(latent_size * 2, decoder_size),
            ReLU(),
            # What is the appropriate size here?
            Linear(decoder_size, ?)
        )

    def encode(self, x, edge_attr, edge_index):
        return self.encoder(x, edge_attr, edge_index)

    def predict(self, z, label, batch_vector):
        z, mask = to_dense_batch(z, batch_vector)
        z_t = z.transpose(1, 2)
        # pair-wise node embeddings of size [batch, max_node, max_node], max_node is taken from
        # biggest graph of the batch!
        pairwise_emb = z @ z_t
        # I want to keep the size [batch, max_node, max_node] here
        return logits, mask

    def training_step(self, batch, batch_index):
        z = self.encode(batch.x,  batch.edge_attr, batch.edge_index)
        # labels are already of size [batch, max_node, max_node]
        labels = batch.y
        logits, masks = self.predict(z, batch.batch)
        loss = self.bce(logits[mask], labels[mask])
        ...

How can I keep the output shape [batch, max_node, max_node], and what output size should the final Linear layer have to achieve that? Thank you.