I am currently working on a research project that applies a Graph Neural Network (GNN)-based model to simulations in a scientific domain. The model takes a graph as input and outputs a matrix. Usually one would then compare this output with the ground truth and compute gradients. In my domain, I think it might be useful to instead use the output to update the node features of the initial graph, which can then be fed back into the model as a new input. I implement this with a for loop; it is similar in spirit to an RNN, but I am trying to take advantage of the graph representation and message passing. The problem is that this loop is very slow. How can I parallelize it across multiple GPUs?
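A minimal sketch of the loop described above, assuming PyTorch and assuming the model maps node features of shape [num_nodes, dim] to an output of the same shape that directly replaces the features at the next step. `SimpleMessagePassing` and `rollout` are hypothetical stand-ins, not the actual model, and the targets are random placeholders rather than real data.

```python
import torch
import torch.nn as nn


class SimpleMessagePassing(nn.Module):
    """Hypothetical one-layer GNN: mean-aggregates neighbour features, then applies an MLP.

    Stand-in for the real model; the only assumption is that it maps
    node features [num_nodes, dim] to an output of the same shape.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(2 * dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index  # messages flow from src to dst
        # Sum the features of incoming neighbours at each destination node.
        agg = torch.zeros_like(x).index_add(0, dst, x[src])
        # Divide by in-degree to get a mean; clamp avoids dividing by zero for isolated nodes.
        deg = torch.zeros(x.size(0), dtype=x.dtype, device=x.device).index_add(
            0, dst, torch.ones_like(dst, dtype=x.dtype)
        )
        agg = agg / deg.clamp(min=1).unsqueeze(-1)
        return self.mlp(torch.cat([x, agg], dim=-1))


def rollout(model: nn.Module, x0: torch.Tensor, edge_index: torch.Tensor, num_steps: int) -> torch.Tensor:
    """Apply the model repeatedly, feeding each output back in as the next node features."""
    x = x0
    outputs = []
    for _ in range(num_steps):
        # Assumed update rule: the output replaces the features (a residual x + model(...) is another option).
        x = model(x, edge_index)
        outputs.append(x)
    return torch.stack(outputs)  # [num_steps, num_nodes, dim]


torch.manual_seed(0)
num_nodes, dim, num_steps = 5, 8, 3
# Directed ring graph: node i sends a message to node (i + 1) % num_nodes.
src = torch.arange(num_nodes)
edge_index = torch.stack([src, (src + 1) % num_nodes])
x0 = torch.randn(num_nodes, dim)
target = torch.randn(num_steps, num_nodes, dim)  # placeholder targets, not real data

model = SimpleMessagePassing(dim)
preds = rollout(model, x0, edge_index, num_steps)
loss = nn.functional.mse_loss(preds, target)
loss.backward()  # gradients flow back through every step of the loop
```

Because step t+1 needs the output of step t, the loop itself cannot be split across GPUs. The usual options are to batch many independent graphs into one larger disconnected graph (offsetting node indices) so each step does more work per kernel launch, and, across GPUs, to give each GPU its own subset of graphs, for example with `torch.nn.parallel.DistributedDataParallel`, so every device runs the full rollout on different data.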

Thanks in advance.