I have written a module for small graphs with a custom collate_batch function, so that each batch contains 128 graphs. However, the graphs have different sizes, so the number of nodes per batch varies. I am using torch_scatter for some operations.
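A common way to batch variable-sized graphs is to concatenate the node features of every graph along dim 0 and record, for each node, the index of the graph it came from. Scatter operations can then pool per graph. The sketch below is a hypothetical version of collate_batch, not the actual one from this module. It assumes each sample is a (node_features, label) pair, and the helper sum_pool is also hypothetical. It uses the built-in Tensor.index_add_ as a stand-in for torch_scatter's scatter-add so that it runs without the extension.

```python
import torch


def collate_batch(samples):
    """Hypothetical collate: samples is a list of (node_feats [n_i, F], label) pairs."""
    xs, ys, batch = [], [], []
    for graph_idx, (x, y) in enumerate(samples):
        xs.append(x)
        ys.append(y)
        # One graph id per node, so pooling knows which graph each node belongs to.
        batch.append(torch.full((x.size(0),), graph_idx, dtype=torch.long))
    # Node features of all graphs stacked along dim 0: [sum(n_i), F].
    return torch.cat(xs, dim=0), torch.cat(batch, dim=0), torch.tensor(ys)


def sum_pool(x, batch, num_graphs):
    """Sum node features per graph (built-in equivalent of a scatter-add over `batch`)."""
    out = x.new_zeros(num_graphs, x.size(1))
    return out.index_add_(0, batch, x)


# Usage: three graphs with 3, 5 and 2 nodes and 4 features per node.
samples = [(torch.ones(3, 4), 0), (torch.ones(5, 4), 1), (torch.ones(2, 4), 0)]
x, batch, y = collate_batch(samples)
pooled = sum_pool(x, batch, len(samples))  # shape [3, 4], one row per graph
```

The key point is that x has one row per node, not per graph, so its first dimension changes from batch to batch even when the number of graphs is fixed at 128.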
My aim is to parallelise my code over multiple GPUs to reduce the wall-clock training time. Naively wrapping the model with nn.DataParallel(model) gives me an error. I believe this is because it requires the chunks it splits each tensor into to all be the same size, which is not the case here: each GPU gets 64 graphs, but the number of nodes varies. Is there currently a way to do this across multiple GPUs, or do I need to take the hit and train on a single GPU?
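To see why a concatenated graph batch interacts badly with nn.DataParallel: its default scatter splits every input tensor into roughly equal chunks along dim 0, much like torch.chunk. It has no notion of which nodes belong to which graph. The CPU-only sketch below mimics that split with torch.chunk on a small hypothetical batch (4 graphs with 2, 6, 3 and 1 nodes, 2 "GPUs"). It is an illustration of the splitting behaviour, not DataParallel's actual code path.

```python
import torch

# Hypothetical batch: 4 graphs with 2, 6, 3 and 1 nodes (12 nodes in total).
sizes = [2, 6, 3, 1]
x = torch.randn(sum(sizes), 8)  # node features [num_nodes, F]
batch = torch.repeat_interleave(torch.arange(len(sizes)), torch.tensor(sizes))
y = torch.arange(len(sizes))  # one label per graph

num_gpus = 2
# Mimic DataParallel's default scatter: split each input evenly along dim 0.
x_parts = torch.chunk(x, num_gpus, dim=0)
batch_parts = torch.chunk(batch, num_gpus, dim=0)
y_parts = torch.chunk(y, num_gpus, dim=0)

for rank in range(num_gpus):
    # Nodes on this replica, their graph ids, and the labels it received.
    print(rank, x_parts[rank].shape[0], batch_parts[rank].tolist(), y_parts[rank].tolist())
```

Each replica receives 6 nodes and 2 labels, but replica 0 gets graph ids [0, 0, 1, 1, 1, 1] and replica 1 gets [1, 1, 2, 2, 2, 3]. Graph 1 is cut in half, and replica 1's graph ids are not renumbered from 0. A scatter-based pooling step on that replica therefore produces more rows than it has labels.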