Suggested design for multiprocess federated learning

Good afternoon,
I have a federated learning codebase that runs the simulation sequentially. I would like to use the PyTorch multiprocessing or distributed packages to leverage multiple GPUs and (optionally) multiple nodes.

Brief description of my use case (FL)
A typical FL simulation consists of several rounds of training. In each round, a subset of compute entities (clients) runs several optimization steps on a private local model with private data. At the end of each round, the gradients of the clients are merged (e.g. by weighted average) and used to update a global model, which then serves as the initialization for the clients selected in the next round.
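
For concreteness, the aggregation step I have in mind is just a weighted average; here is a minimal sketch in plain Python (models reduced to flat lists of floats, weights being e.g. the clients' sample counts):

```python
# Toy sketch of the aggregation step: each "update" is a flat list of
# floats, weighted FedAvg-style by the client's number of local samples.

def weighted_average(updates, weights):
    """Merge client updates: sum_i w_i * update_i / sum_i w_i."""
    total = sum(weights)
    merged = [0.0] * len(updates[0])
    for update, w in zip(updates, weights):
        for j, value in enumerate(update):
            merged[j] += w * value / total
    return merged

# Two clients with 10 and 30 samples respectively.
global_update = weighted_average([[1.0, 2.0], [5.0, 6.0]], weights=[10, 30])
print(global_update)  # [4.0, 5.0]
```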

The code I start from
The code I use is built around three classes:

  • CenterServer: contains the global model, the aggregation logic for the local models, and the testing code
  • Algo: contains the main loop, in which the global model is sent to the clients, the clients perform local training, and the trained models are sent back to the server for aggregation
  • Client: contains a dataloader and a copy of the model.

The CenterServer object and all the Client objects (including all the datasets and dataloaders) are created inside Algo.
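
To make this concrete, here is a stripped-down sketch of the structure (the class names match my code, but the models are plain floats and local_train is a stub standing in for the real PyTorch optimization loop):

```python
# Stripped-down sketch of my current structure; not my real code:
# models are floats and local_train is a stub for the SGD loop.

class Client:
    def __init__(self, data):
        self.data = data          # in the real code: a DataLoader
        self.model = None         # local copy of the global model

    def local_train(self):
        # Stub: the real code runs several optimization steps on self.data.
        self.model = self.model + sum(self.data) / len(self.data)
        return self.model

class CenterServer:
    def __init__(self):
        self.global_model = 0.0   # in the real code: an nn.Module

    def aggregate(self, local_models):
        # Unweighted average for simplicity.
        self.global_model = sum(local_models) / len(local_models)

class Algo:
    def __init__(self, client_datasets):
        self.server = CenterServer()
        self.clients = [Client(d) for d in client_datasets]

    def run(self, rounds):
        for _ in range(rounds):
            local_models = []
            for client in self.clients:   # <- the loop I want to parallelize
                client.model = self.server.global_model
                local_models.append(client.local_train())
            self.server.aggregate(local_models)
        return self.server.global_model

algo = Algo([[1.0, 3.0], [2.0, 4.0]])
print(algo.run(rounds=1))  # 2.5
```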

What I want to do
At the moment the training loop trains all the clients sequentially. I would like to leverage multiple GPUs or, if possible, multiple training nodes as well. I am asking for advice on the best way to port my codebase to achieve that.

  • Solution 1: use processes
    I could use processes to make each client perform its local optimization in parallel. However, I have some doubts about how memory is handled. Let f be the function passed to Process:

    • Do the args passed to f get copied? For example, the client object contains a dataloader, and most of my datasets are preloaded in main memory (as they are small, like CIFAR or MNIST).
    • If, inside f, I set the client’s model to a copy of the server model and then train it, will the trained models still be attributes of the clients after f terminates? Do I need to call share_memory() on them?
    • For this use case I could use mp.Pool to avoid re-creating processes at every round, but the docs say that the processes it creates cannot spawn other processes. This is inconvenient for datasets that need to load data from disk, because I could not set num_workers > 1 in the dataloaders. Is there any workaround?
  • Solution 2: use distributed
    The distributed package seems the most general option and would also allow using more than one node. However, it is even less clear how I should change my current code for it. Also, the docs say I cannot run two processes on the same GPU, while that would be beneficial to me since my models are relatively small and I have already seen a speedup when training multiple clients on the same GPU. Is there a way to do this with distributed?
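
To illustrate my doubt about Solution 1, this is the kind of minimal experiment I have in mind (plain multiprocessing, with a toy object in place of my Client). It suggests the child only gets a copy of the object, so I would have to send the trained model back explicitly through a Queue:

```python
# Toy repro of my doubt: does mutating an object inside the child
# process affect the parent's copy? (It does not: the child works on
# a copy -- pickled under spawn, copy-on-write under fork.)
import multiprocessing as mp

class ToyClient:
    def __init__(self):
        self.model = 0.0

def f(client, queue):
    # Runs in the child: 'client' is a copy of the parent's object.
    client.model = 42.0       # invisible to the parent
    queue.put(client.model)   # results must be sent back explicitly

if __name__ == "__main__":
    client = ToyClient()
    queue = mp.Queue()
    p = mp.Process(target=f, args=(client, queue))
    p.start()
    trained = queue.get()
    p.join()
    print(client.model, trained)  # 0.0 42.0 -> parent's copy unchanged
```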
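
Regarding the mp.Pool limitation, the only workaround I have found (in various forum answers, so I am not sure how safe it is) overrides the daemon flag so that pool workers are allowed to spawn their own children, e.g. DataLoader workers:

```python
# Workaround sketch found on forums, not my own design: a Pool whose
# workers are non-daemonic and may therefore spawn child processes.
import multiprocessing
import multiprocessing.pool

class NoDaemonProcess(multiprocessing.Process):
    # Always report non-daemonic; the setter silently ignores the
    # Pool's attempt to mark its workers as daemonic.
    @property
    def daemon(self):
        return False

    @daemon.setter
    def daemon(self, value):
        pass

class NoDaemonContext(type(multiprocessing.get_context())):
    Process = NoDaemonProcess

class NestablePool(multiprocessing.pool.Pool):
    # A Pool whose workers can create their own subprocesses.
    def __init__(self, *args, **kwargs):
        kwargs["context"] = NoDaemonContext()
        super().__init__(*args, **kwargs)

def _square(x, q):
    q.put(x * x)

def task_with_child(x):
    # Each pool task spawns its own child process, which a normal
    # (daemonic) Pool worker would not be allowed to do.
    q = multiprocessing.Queue()
    child = multiprocessing.Process(target=_square, args=(x, q))
    child.start()
    result = q.get()
    child.join()
    return result

if __name__ == "__main__":
    with NestablePool(2) as pool:
        print(pool.map(task_with_child, [1, 2, 3]))  # [1, 4, 9]
```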

I hope I have made my doubts clear, and that someone can help me figure out the best solution for my case.