Using DDP for a large computation graph

Hi,
I’ve been trying to train a GNN with PyTorch. My sample graphs can have 1-8 nodes.

Unfortunately, the computation graph is too large to fit in the resources I have. I split up the computation for each node in the graph network using a for loop that computes/propagates the messages, calculates the loss, and calls backward(). Inside this loop, I call a couple of functions from my model rather than forward(). This lets me fit the whole thing on a Titan with 12 GB of memory, but I have two GPUs and the second one has remained unused.

I couldn’t use DataParallel or DistributedDataParallel because the wrapped models only expose the forward function. I could technically fit everything into a single forward function and call it multiple times instead of once, but my sample graphs do not all have the same number of nodes, e.g. I can’t call losses.backward() three times in one process and five times in the other.

Is there any other way for me to utilize both GPUs?
I appreciate any tips or advice.

Hey @hos-b, does model parallelism work in your case? I.e., split your model into two shards and put each shard on one GPU.

If the two GPUs are on the same machine, here is a single-machine model parallel tutorial:

If the two GPUs are on different machines, here are some tutorials on RPC:

Hello @mrshenli. Thank you for your reply.

Model parallelism will probably not work either, because even though the torch calls are asynchronous, I have to call backward() at the end of the for loop, which creates a blocking bottleneck; otherwise the computation graph would get too large for a single GPU.

Your reply made me realize I had silently forgotten about DataParallel. Since my batches contain whole graphs, I somehow forgot I could further divide the graph into batches of size 2. I had stopped thinking about DataParallel when I found out about the advantages of DistributedDataParallel.

I see. Have you tried checkpointing part of the autograd graph? This drops the activations and the autograd graph in the forward pass and recomputes them in the backward pass. It’s like paying more compute to reduce the memory footprint.

https://pytorch.org/docs/stable/checkpoint.html
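Something along these lines (a minimal sketch with a placeholder Sequential block; in your case you would wrap whichever part of the GNN holds the largest activations):

```python
import torch
from torch.utils.checkpoint import checkpoint

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # placeholder block; in practice, checkpoint the most memory-hungry
        # part of the model (e.g. one message-passing step)
        self.block = torch.nn.Sequential(
            torch.nn.Linear(1024, 1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, 1024),
        )

    def forward(self, x):
        # activations inside `block` are not stored during the forward pass;
        # they are recomputed during backward
        return checkpoint(self.block, x)

net = Net()
x = torch.randn(4, 1024, requires_grad=True)  # input must require grad so the
                                              # checkpointed segment backprops
net(x).sum().backward()
```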

I have to call backward() at the end of the for loop

Curious, is this because you need to compute the global loss across all iterations, so that you cannot run multiple fwd-bwd iterations to accumulate grads (DDP can support this with the no_sync() context manager) and then run optimizer.step() once to update the parameters?

I did not know about checkpointing. It looks promising. Thank you

Sorry, I didn’t explain my training well. Each graph has up to 8 nodes. A full forward pass for a single node, plus the detached messages from the other nodes, takes about 8 GB of memory. I go through the nodes using a for loop and call backward() on that node’s loss at the end of each iteration, so the gradients do accumulate at each iteration. I do a single optimizer.step() after the loop to update the parameters. In effect I do fwd-bwd 1-8 times, followed by step().
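In toy form, the structure looks like this (a throwaway linear layer and random features stand in for the actual GNN and graph):

```python
import torch

# stand-ins for the real setup: `net` replaces the GNN, `nodes` replaces
# one sample graph with 5 nodes
net = torch.nn.Linear(16, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
nodes = [torch.randn(16) for _ in range(5)]

optimizer.zero_grad()
for feat in nodes:                     # per-node partial forward
    loss = net(feat).pow(2).mean()     # per-node loss
    loss.backward()                    # gradients accumulate in .grad
optimizer.step()                       # single parameter update per graph
```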

From my understanding of DDP, the gradient reduction across processes starts with the backward() call. That’s why I thought it wouldn’t work if different processes called it a different number of times. I did not know about no_sync() either; I’ll have to read up on it. Thanks.

Here is the link: DistributedDataParallel — PyTorch master documentation

no_sync() basically disables DDP gradient synchronization within the context.
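For your case, you can keep every per-node backward() except the last one inside no_sync(), so each process performs exactly one gradient all-reduce per graph no matter how many nodes it loops over. A rough sketch, assuming the process group is already initialized and treating the per-node computation as a plain call into the wrapped model:

```python
from torch.nn.parallel import DistributedDataParallel as DDP

def train_one_graph(ddp_model: DDP, nodes, optimizer):
    optimizer.zero_grad()
    # inside no_sync(), backward() only accumulates gradients locally
    # (no all-reduce), so it is fine if each process loops a different
    # number of times here
    with ddp_model.no_sync():
        for node in nodes[:-1]:
            ddp_model(node).sum().backward()
    # the final backward, outside no_sync(), triggers a single all-reduce
    # of the accumulated gradients in every process
    ddp_model(nodes[-1]).sum().backward()
    optimizer.step()
```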

Thanks. It took me some time to adapt the code to distributed training, but no_sync() did the trick.
