Good afternoon,
I have a federated learning codebase that runs the simulation sequentially. I would like to use the PyTorch multiprocessing or distributed packages to leverage multiple GPUs and (optionally) multiple nodes.
Brief description of my use case (FL)
A typical FL simulation consists of several rounds of training. At each round, a subset of compute entities (clients) runs several optimization steps on a private local model with private data. At the end of each round, the updates of the clients are merged (e.g. by weighted average) and used to update a global model, which then serves as the initialization for the clients selected in the next round.
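Concretely, the merge step amounts to something like this (a minimal FedAvg-style sketch; `client_states` is a list of client state_dicts and `client_weights` their relative weights, both placeholder names):

```python
def fedavg(global_model, client_states, client_weights):
    """Weighted average of client state_dicts into the global model.

    Assumes floating-point entries; integer buffers such as BatchNorm's
    num_batches_tracked would need special handling.
    """
    total = float(sum(client_weights))
    avg_state = {
        key: sum((w / total) * state[key]
                 for state, w in zip(client_states, client_weights))
        for key in client_states[0]
    }
    global_model.load_state_dict(avg_state)
```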
The code I start from
The code I use is built around three classes:

- `CenterServer`: contains the global model, the aggregation logic for the local models, and the testing code.
- `Algo`: contains the main loop, in which the global model is sent to the clients, they perform local training, and the trained models are sent back to the server for aggregation.
- `Client`: contains a dataloader and a copy of the model.

The `CenterServer` object and all the `Client` objects are created inside `Algo` (and with them all the datasets and dataloaders).
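To make the structure concrete, this is roughly the skeleton (heavily simplified; the method names are placeholders for my actual code):

```python
class CenterServer:
    def __init__(self, model, test_loader):
        self.model = model                  # the global model
        self.test_loader = test_loader

    def aggregate(self, client_states, weights):
        ...                                 # weighted average into self.model

    def evaluate(self):
        ...                                 # test the global model


class Client:
    def __init__(self, dataloader, model):
        self.dataloader = dataloader        # private local data
        self.model = model                  # local copy of the model

    def train(self, epochs):
        ...                                 # local optimization steps


class Algo:
    def __init__(self, cfg):
        self.cfg = cfg
        self.server = CenterServer(...)     # server, clients, datasets and
        self.clients = [Client(...) for _ in range(cfg.num_clients)]  # loaders all built here

    def sample_clients(self):
        ...                                 # pick a subset of self.clients

    def run(self, rounds):
        for _ in range(rounds):
            selected = self.sample_clients()
            for client in selected:         # <-- currently sequential
                client.model.load_state_dict(self.server.model.state_dict())
                client.train(epochs=self.cfg.local_epochs)
            self.server.aggregate(
                [c.model.state_dict() for c in selected],
                weights=[1.0] * len(selected),
            )
```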
What I want to do
At the moment the training loop trains all the clients sequentially. I would like to leverage multiple GPUs and, if possible, multiple training nodes, and I am looking for advice on the best way to port my codebase to achieve that.
Solution 1: use processes
I could use processes to make each client perform its local optimization in parallel. However, I have some doubts about how memory would be used. Let `f` be the function passed to `Process` (a sketch of what I have in mind follows this list):

- Do the args passed to `f` get copied? For example, the client object contains a dataloader, and most of my datasets are preloaded into main memory (as they are small, like CIFAR or MNIST).
- If in `f` I set the client's model to a copy of the server model and then train it, will these models still be attributes of the clients after the function terminates? Or do I need to call `share_memory()` on them?
- For this use case I could use `mp.Pool` to avoid creating processes at every round, but in the docs I read that the processes it creates cannot spawn other processes. This is not handy for datasets that need to load data from disk, because I could not set `num_workers > 0` in the dataloaders. Any workaround?
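To make these questions concrete, this is the pattern I have in mind (only a sketch, reusing the `server` and client names from the skeleton above; I am not sure that sending the trained weights back through a manager dict is the right approach):

```python
import torch.multiprocessing as mp

def train_client(client, global_state, return_dict, idx):
    # With the "spawn" start method the args are pickled, so `client`
    # (including its preloaded dataset) arrives as a *copy* in the child
    # process; mutations here are not visible in the parent.
    client.model.load_state_dict(global_state)
    client.train(epochs=1)
    # Send the trained weights back explicitly (CPU tensors pickle cleanly).
    return_dict[idx] = {k: v.cpu() for k, v in client.model.state_dict().items()}

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)  # required when CUDA is involved
    manager = mp.Manager()
    return_dict = manager.dict()
    global_state = server.model.state_dict()

    procs = []
    for i, client in enumerate(selected_clients):
        # Plain mp.Process children are non-daemonic by default, so unlike
        # mp.Pool workers they can spawn DataLoader workers (num_workers > 0).
        p = mp.Process(target=train_client,
                       args=(client, global_state, return_dict, i))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()

    server.aggregate(list(return_dict.values()),
                     weights=[1.0] * len(procs))
```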
Solution 2: use distributed
The distributed package seems the most general and would also allow me to use more than one node. However, it is even less clear to me how I should change my current code. Also, in the docs I read that I cannot run two processes on the same GPU, while that would be beneficial in my case, since the models I use are relatively small and I have already measured a speedup when training multiple clients on the same GPU. Is there a way to do this with distributed?
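For reference, this is roughly what I would try (a sketch meant to be launched with torchrun; `make_model` and `train_locally` are hypothetical placeholders, and I am not sure the backend choice is right — as far as I understand, the one-process-per-GPU restriction applies to the NCCL backend, while gloo lets several ranks map to the same device):

```python
import torch
import torch.distributed as dist

def main():
    # torchrun sets RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT for us.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    # Several ranks can share one GPU here because of the modulo mapping.
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

    model = make_model().to(device)   # hypothetical model factory
    train_locally(model, rank)        # hypothetical local training of one client

    # Merge step: average parameters across all ranks (uniform FedAvg).
    with torch.no_grad():
        for p in model.parameters():
            dist.all_reduce(p, op=dist.ReduceOp.SUM)
            p /= dist.get_world_size()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()  # e.g.: torchrun --nproc_per_node=8 fl_distributed.py
```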
I hope I made my doubts clear, and that someone can help me figure out the best solution for my case.