Torch_geometric: How can I create a lot of graphs efficiently?

Hey! Not sure if this is the right place to ask regarding torch_geometric, but will try anyway.

I am working with a large number of graphs that all have the same number of nodes. In all the examples I have seen, graphs are initialized using the Data class. In my case, I often have a large number of graphs encoded as adjacency matrices in a tensor of shape (batch_size, n, n), although I am open to more memory-efficient representations.

It seems like the only way to turn this into Data objects is to iterate over the batch dimension and create each Data object in a for-loop, which seems incredibly inefficient. How can I initialize all these graphs in a vectorized way? It seems wasteful to already have one large tensor and then iterate over it with a for-loop.

Any help is appreciated! Thanks!

Hi @mzimmer,

Have you looked at torch.func.vmap? I’m not sure how it will interact with torch_geometric, but there’s no harm in trying!

For example,

import torch
from torch.func import vmap

def f(x):
    return x.pow(2)  # example function, applied to a single 6x6 matrix

x = torch.randn(1000, 6, 6)  # 1000 matrices of shape 6x6

# map f over dim 0 of x; the output has shape (1000, 6, 6)
vectorized_output = vmap(f, in_dims=0)(x)
# if f is an nn.Module, call it through torch.func.functional_call

Thanks for the fast reply. I have considered using vmap and will give it a try. Still, I expected there to be some native way of doing this, since it seems like a classic use case, so I thought I might be missing something.

A bit of context: I am trying to generate a lot of graphs repeatedly using a VGAE. When sampling from the latent distribution and decoding, torch_geometric will give me either the edge list or the adjacency matrix (in probabilistic form). Clearly, I can decode a lot of graphs simultaneously using batched tensors. However, turning these into graphs again seems to require calling Data for each batch index, which looks like a complete bottleneck in my code.
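To make the question concrete, here is a minimal sketch of what a loop-free conversion could look like using plain PyTorch only. The function name `batched_adj_to_edge_index` and the hard `threshold` of 0.5 are assumptions for illustration, not part of torch_geometric. The idea is to treat the whole batch as one large block-diagonal graph: node `i` of graph `b` gets the global index `b * n + i`, and a `batch` vector records which graph each node belongs to.

```python
import torch

def batched_adj_to_edge_index(adj_prob, threshold=0.5):
    """Convert (B, n, n) edge probabilities into one block-diagonal edge list.

    Hypothetical helper: keeps an edge when its probability exceeds `threshold`.
    Returns
      edge_index:  (2, E) long tensor, node ids offset by graph (b * n + i)
      batch:       (B * n,) long tensor mapping each node to its graph id
      edge_weight: (E,) probabilities of the kept edges
    """
    B, n, _ = adj_prob.shape
    mask = adj_prob > threshold
    # graph id, source node, target node for every kept edge, in one call
    g, row, col = mask.nonzero(as_tuple=True)
    offset = g * n
    edge_index = torch.stack([row + offset, col + offset], dim=0)
    batch = torch.arange(B, device=adj_prob.device).repeat_interleave(n)
    edge_weight = adj_prob[g, row, col]
    return edge_index, batch, edge_weight

torch.manual_seed(0)
adj_prob = torch.rand(4, 5, 5)  # stand-in for decoded VGAE probabilities
edge_index, batch, edge_weight = batched_adj_to_edge_index(adj_prob)
```

As far as I understand, this is the same layout a torch_geometric mini-batch uses (one large disconnected graph plus a `batch` vector), so these tensors could likely be passed to message-passing layers directly, without building individual Data objects first.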

One thing that comes to mind is whether you can define your data class so that the edges are passed in as an input, rather than being set inside the constructor.

For example, let’s say you had 10 nn.Module objects with different sets of parameters. You can wrap your network in torch.func.functional_call to pass the parameters as an input, then just vmap over all sets. For all of the edges you have, are they all the same shape?
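A minimal sketch of the functional_call + vmap pattern described above, assuming 10 `nn.Linear` modules with identical architecture (the layer sizes and the names `call_one` and `base` are made up for illustration). `torch.func.stack_module_state` stacks the parameters of all modules along a new leading dimension, and vmap then maps over that dimension while the input `x` is shared.

```python
import copy
import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)
# hypothetical setup: 10 modules with the same architecture, different weights
models = [nn.Linear(6, 3) for _ in range(10)]

# every parameter gains a leading dimension of size 10
params, buffers = stack_module_state(models)

# a copy that only provides the module structure; its own weights are not used
base = copy.deepcopy(models[0])

def call_one(p, b, x):
    # run `base` with the given parameters and buffers instead of its own
    return functional_call(base, (p, b), (x,))

x = torch.randn(4, 6)
# map over dim 0 of params and buffers, share x across all 10 modules
out = vmap(call_one, in_dims=(0, 0, None))(params, buffers, x)  # shape (10, 4, 3)
```

This only works when all parameter sets have the same shapes, which is why the question about the shape of the edges matters: vmap needs inputs that can be stacked into one tensor.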