Batching graphs of different sizes when building a GAT in PyTorch Geometric

I am building a GAT using PyTorch Geometric. I’ve manually built a GAT layer that uses a single attention head. However, I am stuck on how to get predictions from the output of the final GAT layer. Each batch contains 32 graphs, and each graph can have a different number of nodes (usually 20-30).
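For context, this is roughly how PyG's DataLoader deals with graphs of different sizes: it stacks the node features of all graphs along dimension 0, shifts each graph's edge_index by the number of nodes that come before it, and records the graph id of every node in the batch vector. The result is one large graph with no edges between the original graphs. The helper collate_graphs below is a hypothetical, pure-PyTorch sketch of that idea, not PyG's actual implementation (which lives in Batch.from_data_list); the toy graphs are made up for illustration.

```python
import torch

def collate_graphs(xs, edge_indices):
    """Hypothetical helper: merge several graphs into one disconnected graph.

    xs: list of (n_i, F) node-feature tensors
    edge_indices: list of (2, E_i) edge_index tensors using local node ids
    """
    x_parts, ei_parts, batch_parts = [], [], []
    offset = 0
    for i, (x, ei) in enumerate(zip(xs, edge_indices)):
        n = x.size(0)
        x_parts.append(x)
        ei_parts.append(ei + offset)  # shift local node ids to batch-wide ids
        batch_parts.append(torch.full((n,), i, dtype=torch.long))  # graph id per node
        offset += n
    return (torch.cat(x_parts, dim=0),
            torch.cat(ei_parts, dim=1),
            torch.cat(batch_parts))

# Two toy graphs with 3 and 2 nodes, 4 features per node
x1, x2 = torch.randn(3, 4), torch.randn(2, 4)
ei1 = torch.tensor([[0, 1], [1, 2]])  # edges 0->1, 1->2
ei2 = torch.tensor([[0], [1]])        # edge 0->1
x, edge_index, batch = collate_graphs([x1, x2], [ei1, ei2])
# x has shape (5, 4); the second graph's edge becomes 3->4; batch is [0, 0, 0, 1, 1]
```

Because the merged graph has no edges between the original graphs, a message-passing layer never mixes information across graphs, and batch is all that is needed later to separate them again.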

The output of my final GAT layer is a tensor of shape (n_nodes, 1), where n_nodes is the total number of nodes across the entire batch. Separately, I have the batch tensor, which maps each node to the index of the graph it belongs to.

My understanding is that I need to pass the nodes of each individual graph through a fully connected layer to obtain a single prediction per graph. That would give a tensor of 32 predictions for the batch, which I would feed into the MSE loss function. However, since the graphs have different sizes, the fully connected layer would need a dynamic input size. Am I thinking about this the right way, and what is the best solution?
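One common way to give the fully connected layer a fixed input size is a readout (global pooling) step: average, sum, or take the max of each graph's node embeddings, so every graph becomes one dim_out-sized vector regardless of how many nodes it has. PyG provides this as torch_geometric.nn.global_mean_pool(x, batch). In keeping with the manual approach, the sketch below implements mean pooling by hand with index_add_; mean_pool, the graph sizes, and the random h are illustrative stand-ins, not part of the original model.

```python
import torch
import torch.nn as nn

def mean_pool(h, batch, num_graphs):
    """Average node embeddings per graph: (n_nodes, dim) -> (num_graphs, dim)."""
    dim = h.size(1)
    # Sum the embeddings of all nodes that share a graph id
    sums = torch.zeros(num_graphs, dim, dtype=h.dtype, device=h.device)
    sums.index_add_(0, batch, h)
    # Count nodes per graph the same way, then divide
    counts = torch.zeros(num_graphs, dtype=h.dtype, device=h.device)
    counts.index_add_(0, batch, torch.ones_like(batch, dtype=h.dtype))
    return sums / counts.clamp(min=1).unsqueeze(1)

torch.manual_seed(0)
dim_out = 8
sizes = [20, 27, 31]  # graphs with different numbers of nodes
batch = torch.cat([torch.full((n,), i, dtype=torch.long) for i, n in enumerate(sizes)])
h = torch.randn(batch.numel(), dim_out)  # stand-in for the final GAT output
fc = nn.Linear(dim_out, 1)  # input size depends on dim_out only, not on graph size
pred = fc(mean_pool(h, batch, len(sizes))).squeeze(-1)  # one value per graph
```

Sum pooling keeps information about graph size, while mean pooling normalises it away; which one suits the target is a modelling choice.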

Here is my GAT layer:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import add_self_loops, to_dense_adj

class GATLayer(nn.Module):
    def __init__(self, in_features, out_features, alpha=0.2, concat=True):
        super(GATLayer, self).__init__()
        self.in_features   = in_features
        self.out_features  = out_features
        self.alpha         = alpha # LeakyReLU negative slope, default 0.2
        self.concat        = concat # concat = True except for output layer.
        self.leakyrelu = nn.LeakyReLU(self.alpha)

        # Initialize weights with the Xavier method
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1))) ## attention vector of size 2 * out_features (one half for each node in the pair)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

    def forward(self, x, edge_index):
        n_nodes = x.shape[0]
        # Keep features as floats: casting to LongTensor would truncate them to integers
        x = x.float()
        # Linear transformation: one out_features-dimensional embedding per node
        h = torch.mm(x, self.W)

        # Apply attention mechanism: a_input[i, j] = [h_i || h_j]
        a_input = torch.cat([
            h.repeat(1, n_nodes).view(n_nodes * n_nodes, -1),
            h.repeat(n_nodes, 1)
        ], dim=1).view(n_nodes, -1, 2 * self.out_features)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2)) # shape (n_nodes, n_nodes): raw coefficients across all nodes

        # Masked attention. Self-loops let each node attend to itself (and stop a node
        # without edges from attending uniformly to the whole batch); max_num_nodes keeps
        # adj at (n_nodes, n_nodes) even if the last nodes in the batch have no edges.
        edge_index, _ = add_self_loops(edge_index, num_nodes=n_nodes)
        adj = to_dense_adj(edge_index, max_num_nodes=n_nodes).squeeze(0) # adjacency matrix
        zero_vec  = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)

        attention = F.softmax(attention, dim=1)
        h_prime = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime
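Because the whole batch goes through the layer as one disconnected graph, the adjacency mask is the only thing that keeps graphs from attending to each other. The sketch below reuses the same masking and softmax as the layer, but with a hand-built 5-node batch and random scores standing in for the real e. Masked rows never put weight on another graph, except for a node with no outgoing edges: its row is entirely masked, so the softmax collapses to a uniform distribution over every node in the batch. Adding self-loops avoids that; PyG's own GATConv adds them by default (add_self_loops=True). The helper masked_attention is illustrative only.

```python
import torch
import torch.nn.functional as F

def masked_attention(e, adj):
    # Same masking as GATLayer.forward: non-edges get a huge negative score
    zero_vec = -9e15 * torch.ones_like(e)
    return F.softmax(torch.where(adj > 0, e, zero_vec), dim=1)

torch.manual_seed(0)
n_nodes = 5  # graph 0: nodes 0-2, graph 1: nodes 3-4
edge_index = torch.tensor([[0, 1, 1, 2, 3],
                           [1, 0, 2, 1, 4]])  # node 4 has no outgoing edge
adj = torch.zeros(n_nodes, n_nodes)
adj[edge_index[0], edge_index[1]] = 1.0  # dense adjacency, rows indexed by source
e = torch.randn(n_nodes, n_nodes)        # stand-in for the raw attention scores

att = masked_attention(e, adj)
print(att[4])     # uniform 0.2 over all 5 nodes, so node 4 leaks into graph 0

att_sl = masked_attention(e, adj + torch.eye(n_nodes))  # with self-loops
print(att_sl[4])  # all weight on node 4 itself
```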

And here is the model. The per-graph readout at the end of forward() is the part I’m confused about:

import torch
import torch.nn as nn

class myGAT(nn.Module):
  def __init__(self, dim_in, dim_h, dim_out):
    '''
    Parameters
    -----------
    dim_in: int
      Size of each input sample.
    dim_h: int
      Size of the hidden layer.
    dim_out: int
      Size of the node embeddings produced by the final GAT layer.
    '''
    super(myGAT, self).__init__()
    self.gat1 = GATLayer(dim_in, dim_h)
    self.gat2 = GATLayer(dim_h, dim_out, concat=False)
    # After pooling, each graph is a single dim_out vector, so the input size is fixed
    self.fc = nn.Linear(dim_out, 1)
    self.optimizer = torch.optim.Adam(self.parameters(), lr=0.01)

  def forward(self, data):
    x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch

    h = self.gat1(x, edge_index)
    h = self.gat2(h, edge_index)  # (n_nodes, dim_out)

    # Per-graph readout: pool each graph's nodes into one vector (mean pooling here,
    # as one option; sum or max pooling also work), then apply the fc layer
    outputs = []
    for i_graph in range(int(batch.max()) + 1):
      h_graph = h[batch == i_graph, :]   # (n_nodes_in_graph, dim_out)
      pooled = h_graph.mean(dim=0)       # (dim_out,)
      outputs.append(self.fc(pooled))    # (1,)

    return torch.stack(outputs).squeeze(-1)  # (num_graphs,), e.g. 32 predictions
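Once the model returns one value per graph, the remaining step is the MSE loss, where a shape mismatch is easy to miss: an fc output of shape (32, 1) compared against targets of shape (32,) broadcasts to (32, 32), and PyTorch only emits a warning. The sketch below uses random tensors as stand-ins for the model output and for the targets, which are assumed here to be one value per graph (for example in data.y).

```python
import warnings
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pred = torch.randn(32, 1)  # stand-in for fc output: one value per graph, shape (32, 1)
target = torch.randn(32)   # stand-in for per-graph targets, shape (32,)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    bad = F.mse_loss(pred, target)  # silently compares every pred with every target

good = F.mse_loss(pred.squeeze(-1), target)  # elementwise over the 32 graphs
```

Squeezing (or reshaping the targets to (32, 1)) so both tensors have the same shape keeps the loss elementwise.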


I realize there are existing layers such as GATConv that do this, but I am implementing it manually as a learning exercise so that I understand how the layer works. I appreciate any advice, thanks very much.