Adding Distributed Model Parallelism to PyTorch

That would be a handy tool.

I’m sure you’ve seen the single-machine version of this, e.g.:
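Something like the standard two-GPU split, I assume (a minimal sketch; the layer sizes and names are just placeholders):

```python
import torch
import torch.nn as nn

class TwoDeviceNet(nn.Module):
    """Single-machine model parallelism: each half lives on its own GPU,
    and the forward pass moves the activations between devices."""
    def __init__(self):
        super().__init__()
        self.part1 = nn.Linear(1024, 1024).to("cuda:0")
        self.part2 = nn.Linear(1024, 10).to("cuda:1")

    def forward(self, x):
        x = torch.relu(self.part1(x.to("cuda:0")))
        return self.part2(x.to("cuda:1"))  # hand the activation to GPU 1
```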

Then, to do it in a distributed setting, I guess you would “supply” the model on node t with the output of node t-1 (where t=0 gets the “actual” inputs). I put “supply” in quotes because you would need to do something like torch.distributed.recv(tensor=<whatever>, src=node_idx - 1). Then chunk up your forward to only run the parts of the model you’ve designated for node t. For the “return,” do something like a torch.distributed.send(...) to node t+1. The idea in my head requires you to have parts of your forward cordoned off with if section_num % num_nodes == node_idx.
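Spelled out, the loop in my head is roughly the following (a sketch only, with made-up names; it assumes every node knows the activation shape up front and uses the blocking point-to-point calls):

```python
import torch
import torch.distributed as dist

def forward_on_node(sections, x, node_idx, num_nodes, act_shape):
    """Run only the sections cordoned off for this node; activations come in
    from whichever node owns the previous section and go out to whichever
    node owns the next one."""
    for section_num, section in enumerate(sections):
        if section_num % num_nodes != node_idx:
            continue  # not this node's section
        prev_owner = (section_num - 1) % num_nodes
        if section_num > 0 and prev_owner != node_idx:
            x = torch.empty(act_shape)           # buffer to receive into
            dist.recv(tensor=x, src=prev_owner)
        x = section(x)
        next_owner = (section_num + 1) % num_nodes
        if section_num < len(sections) - 1 and next_owner != node_idx:
            dist.send(tensor=x, dst=next_owner)
    return x  # only the owner of the last section holds the real output
```

Writing it out also makes it obvious that the modulo assignment forces a transfer at every section boundary; contiguous blocks per node would cut that to one send/recv per node, but then deciding where the block boundaries go is the balancing question.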

I imagine this has already occurred to you. I illustrate it to make the point that I have no idea how you could do the necessary load balancing at that level automatically.

Maybe you could define functions forward_<i> and have some superclass/wrapper inheriting from nn.Module stitch them together for the forward and backward passes. However, I do not know how point-to-point ops (as opposed to collectives) interact with autograd.
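What I mean by stitching is something like this (hypothetical API; the send/recv from the sketch above would sit between the forward_<i> calls, and whether autograd can be threaded across those point-to-point boundaries is exactly the part I don’t know):

```python
import torch.nn as nn

class ChunkedModule(nn.Module):
    """Hypothetical base class: subclasses define forward_0, forward_1, ...
    and the base class chains them in order."""

    def forward(self, x):
        i = 0
        while hasattr(self, f"forward_{i}"):
            x = getattr(self, f"forward_{i}")(x)
            i += 1
        return x


class MyNet(ChunkedModule):
    def __init__(self):
        super().__init__()
        self.block0 = nn.Linear(64, 64)
        self.block1 = nn.Linear(64, 10)

    def forward_0(self, x):
        return self.block0(x).relu()

    def forward_1(self, x):
        return self.block1(x)
```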

But that’s only part of the problem. Usually when you need model parallelism, the actual limitation is the size of the model, not the forward pass. You need to do the same sort of chunking in the __init__ function, so that each node only ever allocates the parameters for its own sections. Maybe the way the user chunks the __init__ could be used to infer the forward_<i>'s from the forward function.
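Concretely, the chunked __init__ I have in mind is something like this (again hypothetical; the point is just that parameters for other nodes’ sections are never allocated locally):

```python
import torch.nn as nn

class ShardedInitNet(nn.Module):
    """Only construct the submodules this node is responsible for."""

    def __init__(self, node_idx, num_nodes, num_sections=8, width=4096):
        super().__init__()
        self.node_idx, self.num_nodes = node_idx, num_nodes
        self.sections = nn.ModuleDict()
        for section_num in range(num_sections):
            if section_num % num_nodes == node_idx:
                # keyed by section index so forward_<i> can look its piece up
                self.sections[str(section_num)] = nn.Linear(width, width)
```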

Even after all that, the problem in my head is saving the state dict when some of the weights live on one machine and the rest on others. Sure, you can gather the weights to one rank when a save is requested, but wouldn’t you need them all in one place to pickle them? At that point, you’ve lost your space savings…
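To make that concrete, the naive version I’m worried about is gathering every shard to rank 0 and pickling there, something like (a sketch; gather_object with CPU copies):

```python
import torch
import torch.distributed as dist

def save_sharded(model, path, node_idx, world_size):
    """Collect every node's (CPU-copied) piece of the state dict on rank 0
    and pickle it there, i.e. the 'everything in one place' step."""
    local = {k: v.cpu() for k, v in model.state_dict().items()}
    shards = [None] * world_size if node_idx == 0 else None
    dist.gather_object(local, shards, dst=0)
    if node_idx == 0:
        merged = {}
        for shard in shards:
            merged.update(shard)  # parameter names are disjoint across nodes
        torch.save(merged, path)
```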

I’m interested in hearing what you have in mind. Maybe I’m overthinking this.