Multi-model and multi-forward in distributed data parallel

Q1: If I have two models named A and B, both wrapped with DDP, and loss = A(B(inputs)), will DDP work? It seems that gradients will be synced when loss.backward() is called.

Q2: If loss = A(B(inputs1), B(inputs2)), will DDP work? The forward function of B is called twice. By the way, I don’t know what reducer.prepare_for_backward does…

It should work. This uses the output from B(inputs) to connect the two graphs together. I don’t think the AllReduce communication from A and B will interleave. If it hangs somehow, you could try setting the process_group argument of the two DDP instances to different ProcessGroup objects created with the new_group API. That fully decouples the communication of A and B.
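For concreteness, here is a minimal sketch of that two-process-group setup; ModelA, ModelB, and inputs are hypothetical placeholders, and it assumes the usual launcher-provided rendezvous environment variables:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")  # assumes torchrun/launch set RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT
ranks = list(range(dist.get_world_size()))
pg_a = dist.new_group(ranks=ranks)  # dedicated communication group for A
pg_b = dist.new_group(ranks=ranks)  # dedicated communication group for B

device = dist.get_rank() % torch.cuda.device_count()
A = DDP(ModelA().to(device), device_ids=[device], process_group=pg_a)
B = DDP(ModelB().to(device), device_ids=[device], process_group=pg_b)

loss = A(B(inputs)).sum()
loss.backward()  # each DDP instance reduces its gradients on its own group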

It seems that gradients will be synced when loss.backward() is called.

Yes, see this page for more detail: Distributed Data Parallel — PyTorch 2.1 documentation
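For context, the AllReduce is triggered by autograd hooks that DDP registers on the parameters, so it runs inside loss.backward(). A minimal sketch of the standard one-GPU-per-process pattern (toy model, for illustration only):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")
rank = dist.get_rank()
model = nn.Linear(10, 10).to(rank)        # toy model; real models work the same way
ddp_model = DDP(model, device_ids=[rank])
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

outputs = ddp_model(torch.randn(20, 10).to(rank))
loss = outputs.sum()
loss.backward()    # gradients are AllReduced across ranks here, bucket by bucket
optimizer.step()   # every rank then applies the same averaged gradients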

Q2: If loss = A(B(inputs1), B(inputs2)), will DDP work? The forward function of B is called twice. By the way, I don’t know what reducer.prepare_for_backward does…

This won’t work. DDP requires forward and backward to run alternately. The above code would run forward on B twice before a single backward, which would mess up DDP’s internal state. However, the following would work. Suppose the local module wrapped by B is C:


import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.c = C()

    def forward(self, inputs):
        # a single DDP forward call processes both inputs
        return self.c(inputs[0]), self.c(inputs[1])

B = DistributedDataParallel(Wrapper(), ...)

loss = A(B([inputs1, inputs2]))

This basically uses a thin wrapper over C to process both inputs in one forward call.


Hi,
I have a different but related problem.
I have a detection model with variable input sizes. In some extreme cases a RuntimeError with OOM occurs, so I wrapped the forward+backward in a try/except:

for images, targets in data_loader:
    images = images.to(device)
    targets = [target.to(device) for target in targets]
    try:
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
    except Exception as ex:
        # free what we can and skip this batch
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        continue

This helps to some extent, until:

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing its output (the return value of forward). You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel. If you already have this argument set, then the distributed data parallel module wasn’t able to locate the output tensors in the return value of your module’s forward function. Please include the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable). (prepare_for_backward at /opt/conda/conda-bld/pytorch_1556653099582/work/torch/csrc/distributed/c10d/reducer.cpp:408)

To be clear, I have already set find_unused_parameters=True and broadcast_buffers=False. My guess is that some internal state was left inconsistent somehow. Is it possible to reset all of that state inside my except block?

Hey @qianyizhang

Yes, the error is expected. Say you have two DDP processes, X and Y. If process X hits an OOM error in the forward pass of one iteration while Y runs correctly, X will skip its backward pass in that iteration, causing a de-synchronization. DDP itself cannot recover from this error.

However, torchelastic is built to solve this exact problem. When such an OOM occurs, it kills the entire DDP gang, constructs a new one, and reverts to the previous checkpoint. cc @Kiuk_Chung


Adding a bit more context to @mrshenli’s comments: you could try to reset the DDP state by calling destroy_process_group() and re-initializing it, but that doesn’t guarantee that your tensors (distributed among multiple workers) are also reset. In short, a complete state reset on the worker is application dependent (and often non-trivial). For transient exceptions you can use torchelastic to launch your workers and just let the worker process throw the exception and fail. Elastic monitors the worker pids and will restart the world if it detects that one (or more) workers have failed. Note that, because of this behavior, you will lose any progress made since the last checkpoint (assuming you have checkpoints).
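As an illustration of the checkpoint-based recovery that elastic relies on, here is a minimal save/resume sketch; the checkpoint path and epoch bookkeeping are illustrative, not part of any library API. With the modern torchrun entry point (which subsumes torchelastic), the job could be launched with something like torchrun --nproc_per_node=8 --max_restarts=3 train.py, and every (re)start would begin by loading the latest checkpoint:

import os
import torch
import torch.distributed as dist

CKPT = "/tmp/ckpt.pt"  # hypothetical checkpoint path

def save_checkpoint(epoch, ddp_model, optimizer):
    # one writer is enough; DDP keeps the replicas identical across ranks
    if dist.get_rank() == 0:
        torch.save(
            {"epoch": epoch,
             "model": ddp_model.module.state_dict(),
             "optim": optimizer.state_dict()},
            CKPT,
        )

def load_checkpoint(ddp_model, optimizer):
    # called at startup on every rank, including after an elastic restart
    if not os.path.exists(CKPT):
        return 0  # fresh start
    state = torch.load(CKPT, map_location="cpu")
    ddp_model.module.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optim"])
    return state["epoch"] + 1  # resume from the epoch after the last completed one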


@Kiuk_Chung Can you elaborate more on the last part?
My goal is to keep the training process going with minimal setback from a failed synchronization.
I assume your elastic feature offers the flexibility of leaving and rejoining the sync pool at any given point?
If one failed worker means all processes have to restart from the last checkpoint, it’s basically the same as running a background monitor that constantly checks, kills, and restarts the whole job, which is not very efficient…

The funny part is that the mentioned RuntimeError only happens 10% of the time, while the training process can tough it out the rest of the time, as if one worker just had a slow start (by retraining a second round). I suspect it hits OOM before finishing the first bucket of gradients…

It’s the latter: elastic monitors worker pids and restarts the world whenever one or more pids fail. Elastic was designed to be general purpose and to fit a variety of distributed training use cases. For pure torch rpc apps it might be possible to allow workers to leave/join the job on the fly, but if you are using process groups, the backend may not allow this by design - for instance, with the NCCL backend, NCCL itself does not allow resizing communicators unless you destroy and recreate them.

@Kiuk_Chung I see.
But is it possible to reset the DDP state on just one worker (with the NCCL backend), without resetting everything, which is somewhat expensive?
After all, in my case the connection is not lost; it is simply halted by the OOM. The ideal scenario could be:
(assuming a bad case where bucket 1’s gradients were already reduced and the reduction is stuck on bucket 2 because worker#0 hit OOM)

  1. worker#0 reruns another batch and only syncs bucket 2 and later, to “catch up” with the rest.
  2. worker#0 keeps syncing but only sends 0.0 gradients, without any further computation.

I imagine solution #1 is very tricky, if not impossible, whereas solution #2 is very much feasible?
If so, can you point me in the right direction (files, watch-outs, etc.)? Thanks!

Thanks, this helps me a lot!

If the collective operation fails (with an OOM or another exception) and you are able to catch it, you may be able to reissue the operation. Whether this works depends on whether the distributed backend is still in a good state. With NCCL, once the state goes bad you have to destroy the process group and re-initialize it. You could try to salvage your processes, but you’d have to call destroy and initialize on all your workers together.

There may be other state in your application that goes out of sync whenever an exception is only observed in a subset of your workers. If you are able to recover that state, you can try to restore the workers to a well-defined point in time. Based on what we’ve observed with distributed applications, it is non-trivial to restore the full application state (DDP + user code) in a distributed setting, and the cleanest restore is to tear the processes down and restore from a checkpoint. This is why elastic was designed the way it is; the trade-off is restart overhead versus correctness and maintainability. Note that with elastic you are simply restarting your worker processes, so the penalty you pay is your initialization time (loading the model, allocating memory for data, etc.), not actually restarting the node or container.
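A minimal sketch of that destroy-and-re-initialize path; it assumes every rank reaches this code collectively (otherwise the re-initialization will hang) and that the launcher’s rendezvous environment variables are still set. Re-wrapping the local module in DDP and rebuilding any other distributed state is left to the application:

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def reset_distributed_state(local_model, device_id):
    # must be called on all workers together
    dist.destroy_process_group()
    dist.init_process_group("nccl")  # reuses RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT from the environment
    return DDP(local_model, device_ids=[device_id])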

Thank you very much. This thread has saved me a lot of headache!