Adding Distributed Model Parallelism to PyTorch

Hi All,

I am a researcher at LBL interested in implementing distributed model parallelism in PyTorch, which could also be useful for our own research. Currently, I am looking at the DistributedDataParallel classes to see how PyTorch splits data across machines internally.

I wonder if the PyTorch community would be interested in this and if there’s already some work on this topic.

Thank you,

I cannot speak for the community, but I would be interested in, and would probably make use of, any model parallelism in PyTorch, especially as it pertains to RNN variants.

That would be a handy tool.

I’m sure you’ve seen, e.g.:

Then, to do it in a distributed setting, I guess you would “supply” the model on node t with the output of node t-1 (where t=1 gets the “actual” inputs). I put supply in quotes because you would need to do something like torch.distributed.recv(tensor=<whatever>, src=node_idx-1). Then chunk up your forward so it only runs the parts of the model you’ve designated for node t. For the “return,” do something like torch.distributed.send(...). The idea in my head requires you to have parts of your forward cordoned off with if section_num % num_nodes == node_idx.
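A minimal sketch of that per-node forward, using the actual torch.distributed names `recv(tensor, src=...)` and `send(tensor, dst=...)`. It assumes the process group is already initialized (e.g. with `torch.distributed.init_process_group`), that each rank holds exactly one stage, and that the incoming activation shape is known in advance, since `recv` fills a pre-allocated tensor. `owns_section` and `stage_forward` are hypothetical names. As far as I know, plain `send`/`recv` are not recorded by autograd, so this only covers the forward pass.

```python
import torch
import torch.distributed as dist
import torch.nn as nn


def owns_section(section_num: int, num_nodes: int, node_idx: int) -> bool:
    """Round-robin assignment of model sections to nodes (hypothetical helper)."""
    return section_num % num_nodes == node_idx


def stage_forward(stage: nn.Module, x, in_shape, rank: int, world_size: int):
    """Run this rank's stage of a simple pipeline (forward only).

    Rank 0 consumes the real input ``x``; every other rank receives its input
    from rank - 1. Every rank except the last sends its output to rank + 1;
    the last rank returns the final output.
    """
    if rank == 0:
        inp = x
    else:
        # recv fills a pre-allocated buffer, so the shape must be known up front.
        inp = torch.empty(in_shape)
        dist.recv(tensor=inp, src=rank - 1)

    out = stage(inp)

    if rank < world_size - 1:
        # Only the data is shipped; the autograd graph stays on this rank.
        dist.send(tensor=out.detach().contiguous(), dst=rank + 1)
        return None
    return out
```

With `world_size=1` no communication happens, which makes the function easy to check locally before running it across real ranks.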

I imagine this has already occurred to you. I illustrate it to make the point that I have no idea how you could do the necessary load balancing at that level automatically.

Maybe you could define functions forward_<i> and have some superclass/wrapper inheriting from nn.Module stitch them together for forward and backward. However, I do not know how point-to-point ops (as opposed to collectives) work with autograd.
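A minimal single-process sketch of that wrapper idea; `StitchedModule` and `TwoStageMLP` are hypothetical names. Because every `forward_<i>` runs locally here, autograd records the whole chain and `backward()` works as usual. The open question in the distributed case is what happens once the hand-off between sections becomes a send/recv.

```python
import torch
import torch.nn as nn


class StitchedModule(nn.Module):
    """Hypothetical base class: subclasses define forward_0, forward_1, ...
    and this class chains them in order to form forward().
    """

    def num_sections(self) -> int:
        # Count consecutively numbered forward_<i> methods.
        n = 0
        while hasattr(self, f"forward_{n}"):
            n += 1
        return n

    def forward(self, x):
        for i in range(self.num_sections()):
            x = getattr(self, f"forward_{i}")(x)
        return x


class TwoStageMLP(StitchedModule):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(16, 4)

    def forward_0(self, x):
        return torch.relu(self.fc1(x))

    def forward_1(self, x):
        return self.fc2(x)
```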

But that’s only part of the problem. Usually when you need model parallelism, the actual limitation is the size of the model, not the forward pass. You would need to do the same sort of thing with the __init__ function. Maybe the way the user chunks __init__ could be used to infer the forward_<i>'s from the forward function.
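A rough sketch of chunking `__init__` the same way, assuming a stack of `nn.Linear` sections assigned round-robin as above. `PartitionedModel` and `section_sizes` are hypothetical. Each node only allocates the sections it owns, so the parameters of other nodes' sections never exist on it.

```python
import torch.nn as nn


class PartitionedModel(nn.Module):
    """Hypothetical sketch: only the sections assigned to this node are built."""

    def __init__(self, section_sizes, num_nodes: int, node_idx: int):
        super().__init__()
        self.owned = nn.ModuleDict()
        for i in range(len(section_sizes) - 1):
            if i % num_nodes == node_idx:
                # ModuleDict keys must be strings.
                self.owned[str(i)] = nn.Linear(section_sizes[i], section_sizes[i + 1])

    def num_params(self) -> int:
        return sum(p.numel() for p in self.parameters())
```

Splitting `[10, 20, 20, 5]` across two nodes gives node 0 sections 0 and 2 and node 1 section 1; their parameter counts add up to the unpartitioned model's.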

Even after all that, the problem in my head is saving the state dict when some of the weights are on one machine and the rest are on others. Sure, you can gather the weights when a save is requested, but wouldn’t you need to have them all in one location to pickle them? At that point, you’ve lost your space savings…
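One way around that, sketched under the assumption that every rank can write its own file: each rank saves only the state dict of the part it owns, so no single process ever holds all the weights. The file-name pattern and the helpers `save_shard`/`load_shard` are hypothetical. Producing one consolidated checkpoint would still require bringing every shard to one place.

```python
import os

import torch
import torch.nn as nn


def save_shard(module: nn.Module, ckpt_dir: str, rank: int) -> str:
    """Write this rank's parameters to its own file (hypothetical layout)."""
    path = os.path.join(ckpt_dir, f"shard_rank{rank}.pt")
    torch.save(module.state_dict(), path)
    return path


def load_shard(module: nn.Module, ckpt_dir: str, rank: int) -> None:
    """Restore this rank's parameters from its own shard file."""
    path = os.path.join(ckpt_dir, f"shard_rank{rank}.pt")
    module.load_state_dict(torch.load(path))
```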

I’m interested in hearing what you have in mind. Maybe I’m overthinking this.

Thank you for the response, Dylan. Yes, I’ve seen the two questions.

I’ve been quiet on this thread mainly because I’ve been trying to dig into the PyTorch code. If I understood correctly, what you described at the beginning is splitting the layers across multiple machines. What I was thinking of is splitting each layer across multiple machines. The reason is that if you split across layers, the machine handling layer i would be idle until layer i-1 is completed anyway, so there’s not much gain from parallelizing the model. On the other hand, if we split a layer across multiple machines, we can exploit parallelism better. Then, once all the computations of a layer are done, we can sync (allgather) the outputs before starting the next layer.
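A single-process sketch of the intra-layer split, assuming a plain `nn.Linear` whose output features divide evenly across machines. Each "machine" gets a block of rows of the weight matrix and computes its slice of the output; concatenating the slices reproduces the full layer, which is what the allgather would do across real processes. The helper names are hypothetical. `allgather_outputs` uses `torch.distributed.all_gather` and only works inside an initialized process group. The shards are detached, so this covers the forward pass only.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


def shard_linear(layer: nn.Linear, num_shards: int):
    """Split a Linear layer's output features into num_shards row blocks."""
    weights = layer.weight.detach().chunk(num_shards, dim=0)
    biases = layer.bias.detach().chunk(num_shards, dim=0)
    return list(zip(weights, biases))


def shard_forward(x, shard):
    """What one machine would compute: its slice of the layer's output."""
    w, b = shard
    return F.linear(x, w, b)


def allgather_outputs(local_out, world_size: int):
    """Multi-process sync step; needs an initialized group and equal-sized shards."""
    gathered = [torch.empty_like(local_out) for _ in range(world_size)]
    dist.all_gather(gathered, local_out)
    return torch.cat(gathered, dim=-1)
```

Splitting by output features means each shard needs the full input but produces a disjoint slice of the output, so the only per-layer communication is the gather before the next layer.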

I am new to PyTorch internals, so I would appreciate any help figuring out the code. I was looking at this post (The pytorch blog "A Tour of PyTorch Internals" is out-of-date. How to know more about the pytorch internal). Is there anything else you can recommend?

Simon Wang’s response there is a very good “quick and high-level description.”

Follow Peter Goldsborough’s tutorial for a hands-on introduction to interfacing between ATen and Python. The PyTorch source doesn’t do the interfacing in exactly the same way, but the tutorial will get you acquainted with ATen.

If you want to dig further, this blog post has a good tour of the internals of ATen.

Do be aware that once you get to the T*C level, you’re up against daily changes (see goldsborough’s first reply on this thread).