Discrepancy between number of edges in batched samples and a single sample in PyTorch Geometric

Hi, I’m having trouble understanding why my batched graphs end up with a different number of edges than they should. I am using the NeuroGraphDataset and trying to train a graph diffusion model. In this dataset, each graph has 400 nodes, and the node features are vectors of correlations with the other nodes (brain fMRI ROI values). This matters because the edges of each graph were computed by taking the top 5% of values in the (400, 400) feature matrix formed by the node feature vectors, so the number of edges differs from graph to graph.
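
To make the setup concrete, here is a minimal stand-in for that thresholding step on a random symmetric matrix (the real features come from NeuroGraphDataset; `top_percent_edges` is my own illustrative name, not the library's):

```python
import numpy as np
import torch

def top_percent_edges(corr: torch.Tensor, percent: float = 5.0) -> torch.Tensor:
    """Keep entries in the top `percent` of *positive* values; return edge_index of shape (2, E)."""
    adj = corr.detach().clone()
    # Percentile is taken over the positive entries only, mirroring the NeuroGraph preprocessing
    cutoff = np.percentile(adj[adj > 0].numpy(), 100 - percent)
    return (adj >= cutoff).nonzero().t().to(torch.long)

torch.manual_seed(0)
for _ in range(3):
    c = torch.randn(400, 400)
    c = (c + c.t()) / 2          # symmetric, like a correlation matrix
    e = top_percent_edges(c)
    print(e.shape)               # the edge count varies between the random graphs
```

Because the cutoff is a per-graph percentile of the positive entries, each graph ends up with its own edge count.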

Here’s where my problem comes in. When I compute the edges with the same method for each graph in a batch of 8 and concatenate the results into a new edge_index, it has far fewer edges than the edge_index that comes from the dataset. The edge-computation method itself is correct: I’ve verified it by running it on every sample in the dataset and checking that the computed edges are the same as the ones stored in the dataset, and indeed they are.
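
As far as I can tell, the offset-and-concatenate step itself cannot drop edges. A quick sanity check on synthetic per-graph edge sets (the sizes here are made up, just chosen to differ like the real dataset's do):

```python
import torch

torch.manual_seed(0)
num_nodes = 400
# Stand-in per-graph edge sets of different sizes, like the real dataset
per_graph = [torch.randint(0, num_nodes, (2, n)) for n in (7000, 6500, 7200)]

# Offset each graph's indices into the batched node numbering, then concatenate
offset = [e + i * num_nodes for i, e in enumerate(per_graph)]
batched = torch.cat(offset, dim=1)

# The total edge count is exactly the sum of the per-graph counts
assert batched.shape[1] == sum(e.shape[1] for e in per_graph)
# Indices of graph i land in [i * 400, (i + 1) * 400)
assert batched.max() < len(per_graph) * num_nodes
```

So the discrepancy must come from somewhere else, not from the concatenation.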

Here’s a code snippet to give you a better picture:

def forward(self, x, edge_index, timestep, **kwargs):
    new_computed_edge_indices = []
    for i in range(len(kwargs["y"])):  # Using the number of class labels as the batch size is clumsy but correct.
        # Go through each graph in the batch and compute its edges
        computed_edges = self.compute_edges(x[i * 400 : (i + 1) * 400])
        computed_edges += i * 400  # Offset the computed indices into the batched node numbering
        new_computed_edge_indices.append(computed_edges)

    concat_new_indices = torch.cat(new_computed_edge_indices, dim=1)

# Shape of edge_index is (2, 56544)
# Shape of concat_new_indices is (2, 36510), when it should be (2, 56544).
# ...

def compute_edges(self, corr, threshold=5):
    """Construct an adjacency matrix from the given correlation matrix and threshold.
    Adapted from the NeuroGraph code: https://github.com/Anwar-Said/NeuroGraph/blob/main/NeuroGraph/preprocess.py#L274
    The threshold is the top 5% for the HCPTask dataset."""
    corr_matrix_copy = corr.detach().clone().to("cpu")
    # The percentile is taken over the positive entries only
    threshold = np.percentile(
        corr_matrix_copy[corr_matrix_copy > 0], 100 - threshold
    )
    corr_matrix_copy[corr_matrix_copy < threshold] = 0
    corr_matrix_copy[corr_matrix_copy >= threshold] = 1
    return corr_matrix_copy.nonzero().t().to(torch.long).to(corr.device)
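
To localize where the edges go missing, I've been counting edges per graph inside the batched edge_index. Since each graph owns a contiguous block of 400 node indices, integer division of the source index recovers the graph each edge belongs to (`edges_per_graph` is just a diagnostic helper I wrote, not a PyG function):

```python
import torch

def edges_per_graph(edge_index: torch.Tensor, num_nodes: int = 400, batch_size: int = 8) -> torch.Tensor:
    """Count edges belonging to each graph in a batched edge_index."""
    graph_id = edge_index[0] // num_nodes          # which graph each edge's source node belongs to
    return torch.bincount(graph_id, minlength=batch_size)

# Demo on a tiny synthetic batched edge_index: graph 0 has 3 edges, graph 1 has 2
ei = torch.tensor([[0, 1, 399, 400, 401],
                   [1, 2,   0, 450, 799]])
print(edges_per_graph(ei, num_nodes=400, batch_size=2))  # tensor([3, 2])
```

Running this on both the dataset's edge_index and my recomputed one shows which graphs in the batch lose edges.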

I’m fairly new to PyG, so maybe there is something about batching I don’t understand. I’ve read the advanced mini-batching page, but I don’t see anything there that would change the number of edges: batching should only offset the indices, not add or remove edges.

The only thing I can think of is that the collate function does something else I’m not aware of. Is there something obvious I’m missing here? I would appreciate any help or explanation of why this could be happening, and of how to correctly compute the edges from the feature matrix.