annisat
(MMLi)
August 27, 2020, 1:26am
1
Hi,
I’m working on modifying my model (including my custom data loader) to fit the structure of DDP. I haven’t given my code a try yet, but I’d like to know more about the synchronization process.
According to the many great threads on this forum, DDP takes care of the synchronization during loss.backward(). But what if the number of samples in each data loader leads to different for-loop counts? Would the processes with n+1 loops be blocked because the processes with n loops never reach that point?
Say I have 401 images, distributed to 4 data loaders with 101, 100, 100, and 100 images respectively. The batch size is 4, so process 0 gets 26 iterations while the others get 25. Would my process group get stuck at the 26th iteration?
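To spell out that arithmetic (just restating the example above, assuming the loaders do not drop the last partial batch):

```
import math

# 401 images split across 4 loaders as 101/100/100/100, batch size 4, no drop_last
per_rank = [101, 100, 100, 100]
iterations = [math.ceil(n / 4) for n in per_rank]
print(iterations)  # [26, 25, 25, 25] -> process 0 runs one extra iteration
```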
Here is a simplified version of part of my code:
#......(some init process including moving self.model to DDP)......
for phase in ['train', 'eval']:
    dist.barrier()
    if phase == 'train':
        self.model.train()
        self.data_loader.train()
    else:
        self.model.eval()
        self.data_loader.eval()
    running_loss = 0
    for inputs, labels in self.data_loader:
        self.optimizer.zero_grad()
        with torch.set_grad_enabled(phase == 'train'):
            outputs = self.model(inputs)
            loss = self.loss(outputs, labels)
            if phase == 'train':
                loss.backward()  ### Could this or the following line get stuck during the extra loop by process 0?
                self.optimizer.step()
        running_loss += loss.item() * inputs.shape[0]
        torch.cuda.empty_cache()
    epoch_loss = running_loss / len(self.data_loader)
Thanks for any helpful hint!
mrshenli
(Shen Li)
August 27, 2020, 2:48am
2
annisat:
According to the many great threads on this forum, DDP takes care of the synchronization during loss.backward(). But what if the number of samples in each data loader leads to different for-loop counts? Would the processes with n+1 loops be blocked because the processes with n loops never reach that point?
Yep, the one with n+1 loops will block on PyTorch v1.6 and earlier. There are ways to work around this in user code, e.g. by all-reducing a signal in each iteration to see whether any process has already exited, and breaking if so.
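A minimal sketch of that workaround, just for illustration (the has_next flag, the next(it, None) lookahead, and the choice of dist.ReduceOp.MIN over a 0/1 tensor are assumptions about one possible implementation, not something DDP provides):

```
import torch
import torch.distributed as dist

def run_epoch(ddp_model, optimizer, loss_fn, data_loader, device):
    data_iter = iter(data_loader)
    batch = next(data_iter, None)
    while True:
        # Every rank contributes 1 while it still has data, 0 once it runs out.
        # MIN over the flags drops to 0 as soon as any rank is exhausted,
        # so all ranks break together instead of one rank hanging in backward().
        has_next = torch.tensor([0.0 if batch is None else 1.0], device=device)
        dist.all_reduce(has_next, op=dist.ReduceOp.MIN)
        if has_next.item() == 0:
            break
        inputs, labels = batch
        optimizer.zero_grad()
        loss = loss_fn(ddp_model(inputs), labels)
        loss.backward()
        optimizer.step()
        batch = next(data_iter, None)
```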
@rvarm1 is working on a much better solution, which will be included in v1.7. With that solution, the process that exits early will use dummy comm ops to unblock remaining active ones. Please see the following issue and PR.
Issue pytorch/pytorch#38174:
## 🚀 Feature
with @pritamdamania87 @mrshenli @zhaojuanmao
This RFC is to summarize the current proposal for supporting uneven inputs across different DDP processes. Related discussion in https://github.com/pytorch/pytorch/issues/33148. An example pain point from a user is on the [PyTorch forums](https://discuss.pytorch.org/t/best-practice-for-uneven-dataset-sizes-with-distributeddataparallel/67308).
#### Problem
[torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel) is a commonly used tool for distributed data-parallel training, but currently obliges the user to provide an equal number of inputs across each participating DDP process (or appropriately handle the error otherwise). DDP currently fails when different processes have an unequal number of inputs to process during training. While there are utilities such as [DataLoader](https://pytorch.org/docs/stable/data.html) and [DistributedSampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler) that make navigating this assumption in DDP easier by evenly distributing the dataset, we can't expect these to solve all use cases and many users have had use cases where uneven inputs need to be supported.
The following script gives a simple example of the error:
```
import torch
import torch.distributed as dist
import os
import torch.multiprocessing as mp
import torch.nn as nn

def worker(rank):
    dist.init_process_group("nccl", rank=rank, world_size=2)
    torch.cuda.set_device(rank)
    model = nn.Linear(1, 1, bias=False).to(rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
    # Create uneven inputs, rank 1 will get one more input than rank 0. This will cause a hang.
    inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
    for _ in range(5):
        for inp in inputs:
            loss = model(inp).sum()
            loss.backward()
    torch.cuda.synchronize(device=rank)

if __name__ == '__main__':
    os.environ["MASTER_ADDR"] = "localhost" ; os.environ["MASTER_PORT"] = "29501"
    mp.spawn(worker, nprocs=2, args=())
```
With the NCCL backend (recommended choice when training with multiple GPUs) this will result in a hang, as process 1 will wait for communication (allreduce) from process 0, but process 0 has already exited its training loop. On the other hand, with the gloo backend, this results in a "Connection reset by peer" error.
#### Proposal
This was proposed by @pritamdamania87 and is influenced by the approach taken by Horovod to resolve a similar problem (https://github.com/horovod/horovod/issues/832).
1. Provide a context manager such as `with torch.nn.parallel.distributed.join()`. In the `__enter__`, we will set a flag indicating that we will run the below process for managing uneven inputs.
2. The context manager's `__exit__` indicates that the process has depleted its input and is ready to join. When a trainer calls this:
a. Schedule an allreduce with `torch.tensor(0)`. This allreduce will match the allreduce scheduled by non-joined processes (explained below in point 3)
b. If the result of the above is zero, this means that all processes have depleted their inputs and we can move to step (d)
c. Otherwise, schedule an allreduce for all buckets in the backwards pass, with all gradients zeroed out (this is so that joined ranks don't affect gradients of the rest of the training processes). This will match the allreduce done in the backwards pass for currently active trainers. Go back to step a.
d. If (a) returns all zeros, this means that all ranks have terminated their inputs and we can move on to cleanup. We also need to keep track of the process with the latest model parameters, and broadcast them to all ranks to maintain the fact that DDP ensures all parameters across processes are the same. We can do this via a simple version counter. In this step, we can then allgather this version counter, and have the process with the maximum counter broadcast its parameters to the rest of the processes. Ties can be broken arbitrarily.
3. If a trainer has not called `__exit__`, then:
a. Before scheduling allreduces for the backwards pass, schedule an allreduce with `torch.tensor(1)`. This allreduce matches the one scheduled in (2a). We can schedule this allreduce in the forward pass, but we should not await it here for performance reasons; it should be awaited at the end of the backwards pass.
b. Schedule allreduce ops for all the buckets as typical in the backwards pass for DDP. Processes which have depleted their inputs will match these allreduces as a result of step 2c. These processes will have zero as the argument for their gradients so they will not contribute to gradient averaging.
c. Instead of dividing by a static world_size, since we now can have a smaller effective world size (initial_world_size - currently_joined_processes), divide by this instead to ensure that we are still correctly averaging gradients. This can be done by taking the value returned in 3a, which will be interpreted as an int representing the number of currently active processes.
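A rough user-level sketch of what a joined rank would do under this scheme (conceptual only; the real logic lives inside DDP's reducer and also has to handle buckets, buffers, and unused parameters):

```
import torch
import torch.distributed as dist

def shadow_after_join(ddp_model):
    # Conceptual sketch of steps 2a-2d above, not the actual DDP internals.
    while True:
        # Matches the "present" allreduce: active ranks contribute 1, joined
        # ranks contribute 0, so the result is the number of active ranks.
        num_active = torch.zeros(1)
        dist.all_reduce(num_active)
        if num_active.item() == 0:
            break  # every rank has depleted its input (step 2d)
        # Match the gradient allreduces of active ranks with zeros (step 2c),
        # so joined ranks do not perturb the averaged gradients.
        for param in ddp_model.parameters():
            dist.all_reduce(torch.zeros_like(param))
```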
#### Code sample
```
model = nn.Linear(1, 1, bias=False).to(rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
# Create uneven inputs
inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
for _ in range(epochs):
    with torch.nn.parallel.distributed.join():
        for inp in inputs:
            loss = model(inp).sum()
            loss.backward()
```
#### Alternatives considered
We considered the alternative of all trainers raising a `StopIteration` once we detect that at least one trainer has depleted its input (via the above method). However, the user would then have to catch this StopIteration and this would also result in all processes stopping their training, whereas the currently proposed method allows training to continue with a smaller effective world size. In the future if we see the need for users to actually stop the training early in these situations, we can provide the appropriate options, although we would like to keep usage of this API as simple as reasonably possible.
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar
Pull request pytorch/pytorch#42577:
Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#42577 Join-based API to support DDP uneven inputs**
Closes https://github.com/pytorch/pytorch/issues/38174. Implements a join-based API to support training with the DDP module in the scenario where different processes have different no. of inputs. The implementation follows the description in https://github.com/pytorch/pytorch/issues/38174. Details are available in the RFC, but as a summary, we make the following changes:
#### Approach
1) Add a context manager that is owned by `class DistributedDataParallel` to coordinate the below process.
2) In the forward pass, we schedule a "present" allreduce where non-joined processes contribute 1 and joined processes contribute 0. This lets us keep track of joined processes and know when all procs are joined.
3) When a process depletes its input and exits the context manager, it enters "joining" mode and attempts to "shadow" the collective comm. calls made in the model's forward and backward pass. For example, we schedule the same allreduces in the same order as the backward pass, but with zeros.
a) There are a number of scenarios where the backward pass involves more collective comm. than just the gradient allreduces; for example, unused param detection and bucket rebuilding require collective comm.
4) We provide an option of whether we should divide by the initial world_size or effective world_size when some ranks are gone (default to initial world_size). If dividing by effective world_size, we adjust the allreduce division logic to divide by the effective world size (no. of non-joined procs) rather than the absolute world size.
5) At the end of training, the last joined process is selected to be the "authoritative" model copy and broadcasts its parameters.
We also make the following smaller changes to support the above:
- Add a `rank` argument to `_distributed_broadcast_coalesced` to specify which rank should do the broadcast, instead of forcing rank 0. This is needed because we cannot select rank 0 arbitrarily in the join-mode.
- Add a helper function to `DistributedDataParallel` which will have all processes agree on a common rank based on some condition. This common rank is then used for broadcasting final model params and module buffers throughout training.
- Expose several helper methods on `Reducer` such as getters for the `Reducer`s `Bucket`s and the ability to invoke `rebuildBuckets()` from Python, to support "shadowing" collective calls in join mode.
#### How is it tested?
We have tests covering the following models/scenarios:
- [x] Simple linear model
- [x] Large convolutional model
- [x] Large model with module buffers that are broadcast in the forward pass (resnet). We verify this with a helper function `will_sync_module_buffers` and ensure this is true for ResNet (due to batchnorm)
- [x] Scenario where a rank calls join() without iterating at all, so without rebuilding buckets (which requires collective comm)
- [x] Model with unused params (with find unused parameters=True)
- [x] Scenarios where different processes iterate for varying numbers of iterations.
- [x] Test consistency in tie-breaking when multiple ranks are the last ones to join
- [x] Test gradient division by the effective world_size (no. of unjoined processes) and the static world_size
- [x] Test that exceptions during training are correctly propagated by the context manager
- [x] Test expected behavior when the manager is disabled with `enable=False` (for debug purposes)
- [x] Test expected behavior when > 1 process joins early (at different iterations)
- [x] Test model equivalence to local training when used with join API.
#### How to run the tests
The main test can be run with `touch /tmp/barrier && TEMP_DIR="/tmp" BACKEND="nccl" WORLD_SIZE="2" python test/distributed/test_distributed.py -v TestDistBackend.test_ddp_uneven_inputs`
#### Performance implications
###### Trunk vs PR patched, 32 GPUs, batch size = 32
P50, forward + backward + optimizer batch latency & total QPS: 0.121 264/s vs 0.121 264/s
P50 backwards only batch latency & total QPS: 0.087 369/s vs 0.087 368/s
###### join(enable=True) vs without join, 32 GPUs, batch size = 32, even inputs
P50, forward + backward + optimizer batch latency & total QPS: 0.120 265/s vs 0.121 264/s
P50 backwards only batch latency & total QPS: 0.088 364/s vs 0.087 368/s
###### join(enable=False) vs without join, 32 GPUs, batch size = 32, even inputs
P50 forward + backward + optimizer batch latency & total QPS: 0.121 264/s vs 0.121 264/s
P50 backwards only batch latency & total QPS: 0.087 368/s vs 0.087 368/s
###### join(enable=True) with uneven inputs (offset = 2000), 32 GPUs, batch size = 32
P50 forward + backward + optimizer batch latency & total QPS: 0.183 174/s vs 0.121 264/s
P50 backwards only batch latency & total QPS: 0.150 213/s vs 0.087 368/s
###### join(enable=True) with uneven inputs (offset = 2000), 8 GPUs, batch size = 32
P50 forward + backward + optimizer batch latency & total QPS: 0.104 308/s vs 0.104 308/s
P50 backwards only batch latency & total QPS: 0.070 454/s vs 0.070 459/s
The two uneven-inputs benchmarks above were conducted with 32 GPUs and with 4 GPUs immediately depleting their inputs and entering "join" mode (i.e. not iterating at all), while the other 28 iterated as normal. It looks like there is a pretty significant perf hit for this case when there are uneven inputs and multi-node training. Strangely, with a single node (8 GPUs), this does not reproduce.
###### join(enable=True) with uneven inputs (offset = 10), 8 GPUs, batch size = 32
P50 forward + backward + optimizer batch latency & total QPS: 0.120 265/s vs 0.120 265/s
P50 backwards only batch latency & total QPS: 0.087 367/s vs 0.087 367/s
Here the difference in inputs is only 10, i.e. the early-joining ranks iterate 10 fewer times than the ones that iterate for the full N, instead of the all-or-nothing split in the tests above.
#### Limitations
1) This is only implemented for MPSD, not SPMD. Per a discussion with @mrshenli we want to encourage the use of MPSD over SPMD for DDP.
2) This does not currently work with SyncBN or custom collective calls made in the model's forward pass. This is because the `join` class only shadows the `broadcast` for buffers in the forward pass, the gradient allreduces in the bwd pass, unused parameters reduction, and (optionally) the rebuild buckets broadcasting in the backwards pass. Supporting this will require additional design thought.
3) Has not been tested with the [DDP comm. hook](https://github.com/pytorch/pytorch/issues/39272) as this feature is still being finalized/in progress. We will add support for this in follow up PRs.
4) Has not been thoroughly tested with DDP + RPC. We plan to add support for this in follow up PRs.
Differential Revision: [D22893859](https://our.internmc.facebook.com/intern/diff/D22893859/)
**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22893859/)!
annisat
(MMLi)
August 27, 2020, 4:56am
3
A thousand thanks for the explanation! I modified my code following your suggestion, and I provide my provisional solution here for comments.
running_loss = 0
running_len = 0
for inputs, labels in self.data_loader:
    self.optimizer.zero_grad()
    with torch.set_grad_enabled(phase == 'train'):
        outputs = self.model(inputs)
        loss = self.loss(outputs, labels)
        if phase == 'train':
            loss.backward()
            self.optimizer.step()
    iteration_count += 1
    running_loss += loss.item()
    running_len += inputs.shape[0]
    torch.cuda.empty_cache()
    ##########
    is_next = torch.Tensor([self.data_loader.peek()])
    # is_next==True if the iterator has not reached the end, i.e., next loop is expected
    dist.all_reduce_multigpu(is_next, op=dist.ReduceOp.BAND)
    if not is_next: break
    ##########
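For reference, peek() is a method of my custom data loader. A hypothetical lookahead wrapper (the names PeekableLoader and _next_batch are made up here) over a plain iterable loader could look like:

```
class PeekableLoader:
    """Wraps an iterable loader and looks one batch ahead so peek() can
    report whether another iteration is coming."""

    def __init__(self, loader):
        self.loader = loader
        self._next_batch = None

    def __iter__(self):
        it = iter(self.loader)
        self._next_batch = next(it, None)
        while self._next_batch is not None:
            current, self._next_batch = self._next_batch, next(it, None)
            yield current

    def peek(self):
        # True while the underlying iterator still has a batch left
        return self._next_batch is not None
```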
mrshenli
(Shen Li)
August 27, 2020, 2:17pm
4
Hey @annisat, that looks good to me. One way to speed it up a bit is to run the dist.all_reduce at the beginning of the loop and set async_op=True, then only wait for it when you need the result. In this way, the comm and the forward/backward/opt.step computation can overlap. Please see the code in the following thread:
The self-contained code below works for me.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

def example(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = nn.Linear(10, 10).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_i…
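As a minimal illustration of the async_op mechanics themselves (separate from the loader logic above): with async_op=True the collective returns a work handle immediately, and wait() is only called once the result is actually needed, so the communication can overlap with local computation.

```
import torch
import torch.distributed as dist

# Illustrative only: launch the all_reduce asynchronously, do unrelated local
# work, then wait() on the handle right before the reduced value is needed.
flag = torch.ones(1)
handle = dist.all_reduce(flag, async_op=True)  # returns a work handle immediately
# ... forward / backward / optimizer.step() could run here ...
handle.wait()      # block until the collective has finished
print(flag)        # the reduced result is now valid
```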
annisat
(MMLi)
September 2, 2020, 3:30am
5
Thanks for the tips! It took me a while to understand and implement async_op.
I would also like to point out a problem I ran into when running my own code above.
I changed my code to
is_next = torch.Tensor([self.data_loader.peek()]).cuda(self.gpu)
col_handle = dist.all_reduce(is_next, op=dist.ReduceOp.BAND, async_op=True)
...
col_handle.wait()
if not is_next: break
and tried it with SPSG with 2 processes. The final value of is_next was [2] rather than [True] or [1]. It seems that dist.ReduceOp.BAND adds up the input tensors rather than doing a regular AND. Therefore I changed the first line into:
is_next = torch.Tensor([self.data_loader.peek()]).bool().cuda(self.gpu)
but the error message says all_reduce does not support this Tensor type for now. In order to achieve my goal, I use dist.ReduceOp.MIN instead. Here is my final code, which runs smoothly without imbalanced for-loop counts blocking synchronization:
for inputs, labels in self.data_loader:
    is_next = torch.Tensor([self.data_loader.peek()]).cuda(self.gpu)
    col_handle = dist.all_reduce(is_next, op=dist.ReduceOp.MIN, async_op=True)
    # forward and backward and step and stuff
    col_handle.wait()
    if not is_next: break