Checkpoint for distributed asynchronous training

Hi, everyone.
I read some documentation and found that many ML systems implement distributed checkpointing synchronously, i.e., every n steps, all nodes stop and dump their state.
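For comparison, here is a minimal sketch of that synchronous scheme. Python threads stand in for nodes and `threading.Barrier` plays the role of "all nodes stop". The toy integer state, `CHECKPOINT_EVERY`, and the in-memory `checkpoints` list are illustrative assumptions, not how PyTorch or any particular system does it. A real multi-process job would use something like `torch.distributed.barrier()` and write to disk instead.

```python
import threading

# Illustrative setup: each "node" is a thread holding a toy integer state.
NUM_NODES = 3
TOTAL_STEPS = 10
CHECKPOINT_EVERY = 4  # the "n" in "every n steps"

state = [0] * NUM_NODES   # per-node training state (toy stand-in)
checkpoints = []          # list of (step, snapshot of all node states)
barrier = threading.Barrier(NUM_NODES)


def node(rank: int) -> None:
    for step in range(1, TOTAL_STEPS + 1):
        state[rank] += rank + 1  # stand-in for one training step
        if step % CHECKPOINT_EVERY == 0:
            # Every node blocks here until all nodes have finished this step.
            if barrier.wait() == 0:  # exactly one thread gets index 0
                checkpoints.append((step, list(state)))
            # Nobody resumes training until the snapshot has been taken.
            barrier.wait()


threads = [threading.Thread(target=node, args=(r,)) for r in range(NUM_NODES)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Because every node is parked at the same step when the snapshot is taken, each checkpoint is globally consistent by construction. The cost is that fast nodes sit idle waiting for slow ones, which is exactly what asynchronous training tries to avoid.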

So I am curious how PyTorch implements distributed checkpointing for asynchronous training, e.g., when Node 1 is at iteration 10, Node 2 is at iteration 15, and so on.

Does anyone know how PyTorch checkpoints the state for asynchronous training?

Thanks!

Hey @steamgjk, do you mind sharing which async training paradigm you are using? E.g., Parameter Server, Gossip, etc.

I prefer Parameter Server, though I think this is a general question. Let's discuss it in the context of a PS-based architecture, @mrshenli.

Since training is asynchronous, every worker can run at a different speed. For example, Worker 1's weights after its 1st iteration and Worker 2's weights after its 5th iteration may arrive together and both be merged into the parameters stored on the PS. If every worker and every PS still checkpoints individually without any coordination, then after a failure the combined checkpoint is not globally consistent and cannot be relied on for correctness, right?
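To make the "different speeds" point concrete, here is a toy sketch in which a fast worker pushes four updates while a slow worker is still computing against version 0. The `ParameterServer` class, the single scalar weight, and the absence of networking are all simplifying assumptions. Each merged update is logged with its staleness, i.e., how many updates were merged after the version it was computed from.

```python
from dataclasses import dataclass, field


@dataclass
class ParameterServer:
    """Hypothetical single-shard PS holding one scalar weight."""
    weights: float = 0.0
    version: int = 0                         # number of updates merged so far
    log: list = field(default_factory=list)  # (worker, base_version, staleness)

    def pull(self):
        return self.weights, self.version

    def push(self, worker, delta, base_version):
        # Staleness = updates merged since the worker last pulled.
        staleness = self.version - base_version
        self.weights += delta
        self.version += 1
        self.log.append((worker, base_version, staleness))


ps = ParameterServer()
_, v_fast = ps.pull()  # both workers start from version 0
_, v_slow = ps.pull()

# The fast worker pushes and re-pulls four times...
for _ in range(4):
    ps.push("worker-1", -0.1, v_fast)
    _, v_fast = ps.pull()

# ...before the slow worker's update, computed against version 0, arrives.
ps.push("worker-2", -0.1, v_slow)
```

After this runs, `ps.log[-1]` is `("worker-2", 0, 4)`: the PS merged an update that was four versions out of date, and nothing in the PS state alone records which worker was at which iteration.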

Will each trainer host a complete model instance in each iteration? If so, any trainer can save the checkpoint, right? You don't have to save the most up-to-date checkpoint; it works as long as it provides a reasonable state to roll back to. Yes, some work will be wasted when a failure occurs, but the same problem exists in synchronous training if you don't save the model on every iteration.
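A minimal sketch of this suggestion, assuming a single trainer that holds a full model copy (a toy dict here, not a real `nn.Module`) and saves it every `save_every` iterations. In PyTorch the save step would typically be `torch.save(model.state_dict(), path)`; `copy.deepcopy` is used instead to keep the example self-contained. The function name and the failure injection are hypothetical.

```python
import copy


def train_with_periodic_save(total_iters, save_every, fail_at):
    """Toy trainer that saves its full model every `save_every` iterations.

    Returns (restored_iteration, restored_model, wasted_iterations).
    """
    model = {"w": 0.0}
    last_saved = (0, copy.deepcopy(model))  # initial state is a valid rollback point
    for it in range(1, total_iters + 1):
        model["w"] += 1.0  # stand-in for applying one update
        if it == fail_at:
            # Roll back to the last saved copy; iterations since then are lost.
            saved_iter, saved_model = last_saved
            return saved_iter, saved_model, it - saved_iter
        if it % save_every == 0:
            last_saved = (it, copy.deepcopy(model))
    return total_iters, model, 0
```

For instance, `train_with_periodic_save(total_iters=10, save_every=5, fail_at=7)` rolls back to iteration 5 and loses 2 iterations of work. The trade-off is the same as in synchronous training: a larger `save_every` means less checkpointing overhead but more wasted work per failure.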

Let’s consider this scenario:
There are 2 PSes and 3 workers. PS-1 holds half of the model parameters and PS-2 holds the other half. During normal operation, Worker 1 simply pushes its parameters to both PSes. Since there is no coordination, PS-1 may well checkpoint its local state after Worker 1's update has been merged, while PS-2 checkpoints its local state before Worker 1's update has been merged. In that case, the global snapshot is not consistent. When the PSes fail and recover from this snapshot (i.e., half of the parameters include Worker 1's update and the other half do not), wouldn't that lead to incorrect results? @mrshenli
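The torn snapshot in this scenario can be reproduced with a toy model of two PS shards, each recording the IDs of the updates it has merged. `PSShard`, `is_consistent`, and the update ID format are hypothetical. "Consistent" here simply means every shard reflects the same set of updates, which is a simplification of a consistent global cut.

```python
from dataclasses import dataclass, field


@dataclass
class PSShard:
    """Hypothetical PS shard holding part of the model."""
    name: str
    params: list
    merged: set = field(default_factory=set)  # IDs of updates applied to this shard

    def apply(self, update_id, delta):
        self.params = [p + d for p, d in zip(self.params, delta)]
        self.merged.add(update_id)

    def checkpoint(self):
        # Local, uncoordinated snapshot of this shard only.
        return {"params": list(self.params), "merged": set(self.merged)}


ps1 = PSShard("PS-1", [0.0, 0.0])  # first half of the model
ps2 = PSShard("PS-2", [0.0, 0.0])  # second half of the model

# Worker 1's update is split across both shards, but the shards checkpoint
# at different moments relative to it.
ps1.apply("worker-1/step-1", [0.5, 0.5])
ckpt1 = ps1.checkpoint()  # PS-1 checkpoints AFTER merging
ckpt2 = ps2.checkpoint()  # PS-2 checkpoints BEFORE merging
ps2.apply("worker-1/step-1", [0.5, 0.5])


def is_consistent(*ckpts):
    """True if every shard checkpoint reflects the same set of updates."""
    first = ckpts[0]["merged"]
    return all(c["merged"] == first for c in ckpts)


# Model reassembled from the two shard checkpoints after a failure.
restored = ckpt1["params"] + ckpt2["params"]
```

Here `is_consistent(ckpt1, ckpt2)` is `False` and `restored` is `[0.5, 0.5, 0.0, 0.0]`: exactly the half-applied update described above. Tracking merged update IDs per shard at least makes the inconsistency detectable, although it does not by itself fix it.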

Is it possible to checkpoint on the worker instead of the PS? Since the worker runs the forward pass through the full model, I assume the parameters in that model should be consistent?

Hi, Shen. @mrshenli
I thought about it for a while and later realized that checkpointing on the workers can also cause inconsistency problems.

Again, consider asynchronous training: suppose there are 2 PSes and 3 workers.
Worker 1 is training with Data-1, Data-2 and Data-3.
Worker 2 is training with Data-4, Data-5 and Data-6.
Worker 3 is training with Data-7, Data-8 and Data-9.

Then:

  1. All workers pull parameters (weights) from PSes.
  2. Worker 1 trains on Data-1 and computes Update-1, Worker 2 trains on Data-4 and computes Update-4, and Worker 3 trains on Data-7 and computes Update-7. They all push their updates to the PSes.
  3. Worker 1 pulls parameters that include only Update-1 and Update-4 and continues training with Data-2. Worker 2 also pulls parameters that include only Update-1 and Update-4 and continues with Data-5. Worker 3 pulls parameters that include Update-1, Update-4, and Update-7 and continues with Data-8. [It is already getting messy.]
  4. After several iterations, each worker holds a model that includes updates of different ages. Then some node fails, and every worker has a saved model. Here is the question: which worker's model should be used? And if we use that worker's model, from which data should the other workers resume training?
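The first three steps above can be traced with a toy simulation that records, for each worker, which updates its local model contains. The helper names (`push`, `pull`, `missing_updates`) and the arrival order (Update-7 merged after Workers 1 and 2 pull) are assumptions chosen to match step 3; no actual training happens.

```python
# Each worker's data shard, in the order it will be consumed.
data = {
    "worker-1": ["Data-1", "Data-2", "Data-3"],
    "worker-2": ["Data-4", "Data-5", "Data-6"],
    "worker-3": ["Data-7", "Data-8", "Data-9"],
}
ps_merged = []                                # update IDs merged at the PSes, in arrival order
local_model = {w: frozenset() for w in data}  # step 1: everyone pulls the initial weights
next_batch = {w: 0 for w in data}             # index of the next batch each worker will use


def push(worker):
    """Worker trains on its next batch and pushes the resulting update."""
    batch = data[worker][next_batch[worker]]
    ps_merged.append("Update-" + batch.split("-")[1])
    next_batch[worker] += 1


def pull(worker):
    """Worker replaces its local model with whatever the PSes have merged so far."""
    local_model[worker] = frozenset(ps_merged)


def missing_updates(candidate):
    """Updates already produced (data already consumed) that the candidate's model lacks."""
    return set(ps_merged) - local_model[candidate]


# Step 2: all three workers push, but Update-7 arrives last.
push("worker-1")
push("worker-2")
# Step 3: Workers 1 and 2 pull before Update-7 is merged; Worker 3 pulls after.
pull("worker-1")
pull("worker-2")
push("worker-3")
pull("worker-3")
```

At this point Workers 1 and 2 hold `{Update-1, Update-4}`, Worker 3 holds all three updates, and the workers continue with Data-2, Data-5, and Data-8. `missing_updates("worker-1")` returns `{"Update-7"}`: restoring Worker 1's model would silently drop the effect of Data-7, even though Worker 3 has already moved past it.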