Training a neural network with blocks of layers on different devices

Hello, I’m trying to train a neural network where the former part of the network (say block A) has to be hosted on one device, while the latter part (say block B) runs on another device for privacy reasons. Is there a built-in way to accomplish this using PyTorch? I already looked at Accelerate, but was not able to find a solution to the problem.

Even if there is no built-in solution in PyTorch, how would you approach the problem? A practical use case may be a document image classification network (say LayoutLM) applied to medical diagnostic data, where the raw data cannot be moved outside the healthcare organization’s VPC. Since the devices in the VPC are not well equipped for training, a suitable route would be to run only the former part of the network, which generates the text and image embeddings, on the devices in the VPC, and then run the latter part, presumably the decoder, and the head on cloud devices (say Google Vertex AI).

Is there any way to accomplish this goal? For instance, is there a way to allow the back-propagation of gradients to flow through different devices along the depth of the network?

If there is no easy solution to the problem, how would you approach it? Even if I split the encoder and the decoder blocks of the original network between the VPC devices and the cloud devices, it is not clear to me how I could train the whole network in an end-to-end fashion, assuming this is the only possible configuration of devices available (i.e., the full network does not fit in the VPC).

Your use case sounds like Federated Learning, where the data stays with the producers for privacy reasons. Did you already take a look at PySyft or other libraries built on PyTorch for this use case?

Thanks @ptrblck,

I edited the response, since I just had a deeper look at the PySyft package. From my understanding, PySyft allows external parties to access privacy-protected datasets exposed on domain nodes, with privacy budget constraints on how much private information is made available to them. This is surely interesting, but it would require a complete redesign of the architecture of the system the AI model runs in. Unfortunately, I do not have the power to control this. I also read the How-to Guides section of the documentation, and from the provided (not fully complete) examples it seems that the dataset API is more suitable for tabular data, where the privacy concerns relate to, e.g., the IDs of patients in medical records. In our use case, we are working with images of documents submitted by the patients, so we would need a way not to leak the content of said images to cloud devices.

I am not that experienced in federated learning, but are there any interesting resources that you could suggest on the topic? The least clear part to me is how we could perform end-to-end training of the two sub-networks, given that they will be hosted on different machines in different clusters and/or VPCs. The major deep learning frameworks allow different parameters of a network to live on different devices (see the sketch below), so I would assume our problem is solvable, since it is almost the same idea, despite some of the devices being hosted on third-party cloud services.
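To make sure we are on the same page, this is the kind of single-machine model parallelism I have in mind (a minimal, untested sketch assuming two local CUDA devices are available; autograd follows the activations across devices, so the backward pass crosses the device boundary automatically):

```python
import torch
import torch.nn as nn

class TwoDeviceNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.block_a = nn.Linear(128, 64).to("cuda:0")  # former part (block A)
        self.block_b = nn.Linear(64, 10).to("cuda:1")   # latter part (block B)

    def forward(self, x):
        h = self.block_a(x.to("cuda:0"))
        return self.block_b(h.to("cuda:1"))             # only activations cross the boundary

net = TwoDeviceNet()
loss = net(torch.randn(4, 128)).sum()
loss.backward()  # gradients flow back from cuda:1 into cuda:0
```

The open question for me is whether something analogous works when the two "devices" are machines in different VPCs.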

Up!

Any other ideas?

The name is a bit of a misnomer, but I think what you are after is Pipeline Parallelism. Using PyTorch’s RPC API, torch.distributed.rpc, it’s possible to coordinate training between multiple machines, where different parts of the network are executed across the nodes.

If you’re able to connect to nodes outside the VPC and send other types of data, you could configure the master node inside the VPC to perform a small forward pass that produces intermediate results, pass these on to worker nodes outside the VPC to run more of the network, receive the outputs back from those workers, and then start the distributed backward pass and parameter updates on the master node within the VPC.
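A minimal sketch of that flow, loosely following the pattern in the PyTorch pipeline-parallelism tutorial (names like Shard2 and "worker1", and the layer sizes, are placeholders; the worker side would call rpc.init_rpc("worker1", rank=1, world_size=2) and then rpc.shutdown()):

```python
import torch
import torch.nn as nn
import torch.distributed.rpc as rpc
import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef

class Shard2(nn.Module):                        # runs on the worker outside the VPC
    def __init__(self):
        super().__init__()
        self.layers = nn.Linear(64, 10)

    def forward(self, x_rref):
        return self.layers(x_rref.to_here())    # pull the intermediate activations

    def parameter_rrefs(self):
        return [RRef(p) for p in self.parameters()]

# On the master inside the VPC, after rpc.init_rpc("master", rank=0, world_size=2):
shard1 = nn.Linear(128, 64)                     # small local shard on the VPC node
shard2_rref = rpc.remote("worker1", Shard2)     # remote shard on the cloud worker

params = [RRef(p) for p in shard1.parameters()]
params += shard2_rref.rpc_sync().parameter_rrefs()
opt = DistributedOptimizer(torch.optim.SGD, params, lr=1e-3)

with dist_autograd.context() as ctx_id:
    h = shard1(torch.randn(4, 128))                  # small forward pass inside the VPC
    out = shard2_rref.rpc_sync().forward(RRef(h))    # only activations leave the VPC
    loss = out.sum()
    dist_autograd.backward(ctx_id, [loss])           # distributed backward across nodes
    opt.step(ctx_id)                                 # updates params on both nodes
```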

It’s possible to configure this type of distributed model; however, your networking constraints could make it more challenging, and I can’t give much more guidance.

See:

Thank you @Jamie_Donnelly,

I didn’t know about PyTorch’s RPC API, but it definitely seems like what I was looking for. I had a look at both your repository and the tutorial in the RPC API docs. There’s only one missing piece that is not covered, i.e., how can pre-trained weights be loaded in such a formulation?

Maybe I have to investigate the behavior of remote() and to_here() a bit more.
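For reference, my current understanding of those two calls is roughly the following (a quick sketch, assuming rpc.init_rpc() has already been called on both nodes):

```python
import torch
import torch.distributed.rpc as rpc

# remote() runs the callable on the owner node and immediately returns an RRef,
# without copying the result back to the caller.
rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), torch.ones(2)))

# to_here() fetches the value behind the RRef to the local node on demand.
value = rref.to_here()   # tensor([2., 2.])
```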

As you suggested, I split the original model into two shards, such that the constrained (VPC-protected) master node performs a small forward step, while the rest runs on powerful GPU worker nodes.

I agree, I think performing a small forward step on the master within the VPC before transferring out is the way to go. You could initialise the subnetworks with the desired pre-trained weights with an rpc.remote call, i.e.,

rpc.remote(workers[0], Subnetwork1, args=(pretrained_weights,))

And then within the class Subnetwork1 you just initialise the weights from the tensor arguments.

However, in that case, pretrained_weights would need to fit in the master node’s memory and then be broadcast to the worker nodes, which in some cases might be a big communication overhead. You are probably better off loading the weights from disk on the worker nodes, i.e., weights = torch.load(f'weights{rank}.pt'), or something similar.
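A rough sketch of that second option (the layer sizes, file name and Subnetwork1 itself are just placeholders): each shard loads its own state_dict from the worker’s local disk when it is constructed, so no weight tensors travel over the wire.

```python
import torch
import torch.nn as nn

class Subnetwork1(nn.Module):
    def __init__(self, weights_path="weights1.pt"):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
        # Load this shard's pre-trained weights from the worker's own storage.
        self.layers.load_state_dict(torch.load(weights_path, map_location="cpu"))

    def forward(self, x_rref):
        return self.layers(x_rref.to_here())

# On the master, construction still goes through rpc.remote, but only the
# file name is sent, not the weights:
# subnet1_rref = rpc.remote("worker1", Subnetwork1, args=("weights1.pt",))
```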

Hey Jamie,

Instead of manually partitioning your model when checkpointing it, it might make more sense to use PyTorch’s distributed checkpoint API: Distributed Checkpoint - torch.distributed.checkpoint — PyTorch master documentation
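A rough sketch of the basic save/load flow (single process here just for illustration; the directory name and tiny model are placeholders, and the exact entry points may differ slightly across PyTorch versions, so check the linked docs):

```python
import os
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

# In practice every rank runs this; here we set up a 1-process group.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(8, 2)
state_dict = {"model": model.state_dict()}

# Each rank writes only its own shard of the state into the checkpoint directory.
dcp.save_state_dict(state_dict=state_dict,
                    storage_writer=dcp.FileSystemWriter("checkpoint_dir"))

# Loading happens in place, into an already-allocated state_dict on each rank.
dcp.load_state_dict(state_dict=state_dict,
                    storage_reader=dcp.FileSystemReader("checkpoint_dir"))
model.load_state_dict(state_dict["model"])

dist.destroy_process_group()
```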


Hi @kumpera,

How would you combine the distributed checkpoint API with the distributed example Jamie detailed?

As for the initialization of pre-trained weights, @Jamie_Donnelly, I followed your suggestion and let each worker load its shard of the pre-trained weights from storage to avoid the start-up overhead.

I am closing the discussion (as I feel the original question has been answered), but @kumpera feel free to add more details about the distributed checkpoint API that you mentioned above.