Hey! Not sure if this is the right place to ask regarding torch_geometric, but will try anyway.
I am working with a large number of graphs that all have the same number of nodes. In all the examples I have seen, the graphs are initialized using the Data structure. In my case, I often have a large number of graphs encoded as adjacency matrices of shape (batch_size, n, n); however, I am open to more memory-efficient representations.
It seems like the only way to turn this into Data objects is to iterate over the batch dimension and create every Data object in a for-loop, which seems incredibly inefficient. How can I initialize all these graphs in a vectorized way? It seems wasteful to already have a large tensor and then iterate over it with a for-loop.
Any help is appreciated! Thanks!
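One possible vectorized route, sketched here under the assumption that a single mini-batch is acceptable instead of a Python list of separate objects: PyG-style mini-batches treat many graphs as one large disconnected graph, described by an edge_index with per-graph node offsets and a batch vector. The helper name dense_batch_to_edge_index is hypothetical and the code uses plain torch only; passing the resulting tensors to a single Data or Batch object is left out.

```python
import torch

def dense_batch_to_edge_index(adj):
    """Convert a (batch_size, n, n) adjacency tensor into one big edge list.

    Hypothetical helper, not part of torch_geometric.
    Returns:
      edge_index: (2, num_edges), node ids offset by graph_id * n
      batch:      (batch_size * n,), graph id of every node
    """
    batch_size, n, _ = adj.shape
    # Coordinates of every nonzero entry: graph id, source node, target node.
    graph_id, src, dst = adj.nonzero(as_tuple=True)
    # Shift node ids so that graph k owns nodes k*n ... (k+1)*n - 1.
    offset = graph_id * n
    edge_index = torch.stack([src + offset, dst + offset], dim=0)
    batch = torch.arange(batch_size, device=adj.device).repeat_interleave(n)
    return edge_index, batch

# Tiny illustrative input: 3 graphs with 4 nodes each, one edge per graph.
adj = torch.zeros(3, 4, 4)
adj[0, 0, 1] = 1
adj[1, 2, 3] = 1
adj[2, 3, 0] = 1
edge_index, batch = dense_batch_to_edge_index(adj)
```

No Python loop over the batch dimension is involved, so the cost is dominated by a single nonzero call on the full tensor.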
Have you looked at torch.func.vmap? I’m not sure how it will interact with torch_geometric, but there’s no harm in trying!
import torch
from torch.func import vmap

def f(x):
    return x.pow(2)  # example function

x = torch.randn(1000, 6, 6)  # 1000 matrices of size 6x6
vectorized_output = vmap(f, in_dims=0)(x)  # f is applied to each 6x6 matrix
# If f is an nn.Module, wrap the call in torch.func.functional_call
Thanks for the fast reply. I have considered using vmap and will try it. Still, I expected there to be some native way of doing this, since it seems like a classic use case; I thought I might be missing something.
A bit of context: I am trying to generate a lot of graphs repeatedly using a VGAE. When sampling from the latent distribution and decoding, torch_geometric gives me either the edge list or the adjacency matrix (in probabilistic form). Clearly, I can decode many graphs simultaneously using batched tensors. However, turning these into graphs again seems to require calling Data for each batch index, which looks like a major bottleneck in my code.
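To make the batched decoding step concrete, here is a minimal sketch in plain torch. It assumes an inner-product decoder, undirected graphs, and no self-loops; the tensor sizes and the random z are illustrative stand-ins for real latent samples, not values from this thread.

```python
import torch

torch.manual_seed(0)
batch_size, n, latent_dim = 8, 6, 4          # illustrative sizes (assumption)
z = torch.randn(batch_size, n, latent_dim)   # stand-in for sampled latent codes

# Inner-product decoder applied to the whole batch at once: (B, n, n).
adj_prob = torch.sigmoid(torch.bmm(z, z.transpose(1, 2)))

# Sample each undirected edge once from the strict upper triangle,
# then mirror it so every sampled adjacency matrix is symmetric.
upper = torch.bernoulli(adj_prob).triu(diagonal=1)
adj = upper + upper.transpose(1, 2)
```

The result is a binary (batch_size, n, n) tensor, i.e. the same dense format described in the original question, produced without iterating over the batch.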
One thing that comes to mind is whether you can define your data class as an object that takes the edges as an input rather than within the constructor.
For example, let’s say you had 10 nn.Module objects with different sets of parameters. You can wrap your network in torch.func.functional_call to pass the parameters as an input, then just vmap over all sets. Are all of your edge tensors the same shape?
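The functional_call idea can be sketched as follows. This is only an illustration of the pattern with 10 small nn.Linear modules (an arbitrary choice, not the poster's model); call_single is a hypothetical helper name, and whether the same approach carries over to torch_geometric layers is untested here.

```python
import copy
import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)
# Ten modules with identical structure but different parameters.
models = [nn.Linear(6, 3) for _ in range(10)]

# Stack the parameters (and buffers) along a new leading dimension.
params, buffers = stack_module_state(models)

# A structure-only copy on the "meta" device; it holds no real weights.
base = copy.deepcopy(models[0]).to("meta")

def call_single(p, b, x):
    # Run `base` using the given parameters instead of its own.
    return functional_call(base, (p, b), (x,))

x = torch.randn(5, 6)
# Map over the stacked parameter sets; the input x is shared (in_dims=None).
out = vmap(call_single, in_dims=(0, 0, None))(params, buffers, x)
```

Here out has shape (10, 5, 3): one output per parameter set, computed in a single vectorized call.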