Handling OOM in the backward pass when using DDP

Sometimes OOM errors occur, and the way I typically handle this is the following:

import gc

import torch

fail_counter = 0  # consecutive OOM failures

for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):
    try:
        iteration_output = _do_iteration(...)
        output = iteration_output.output_image
        loss_dict = iteration_output.data_dict
    except RuntimeError as e:
        # The OOM can only be detected by matching on the error string,
        # which may change between PyTorch versions.
        if "out of memory" in str(e):
            if fail_counter == 3:
                raise TrainingException(
                    f"OOM, could not recover after 3 tries: {e}."
                )
            fail_counter += 1
            logger.info(
                f"OOM Error: {e}. Skipping batch. Retry {fail_counter}/3."
            )
            # Drop any partially computed gradients and return cached
            # blocks to the CUDA allocator before continuing.
            optimizer.zero_grad()
            gc.collect()
            torch.cuda.empty_cache()
            continue
        logger.info(f"Cannot recover from exception {e}. Exiting.")
        raise RuntimeError(e)
    fail_counter = 0

Note: While parsing the error string is suboptimal, there does not appear to be an alternative (I opened an issue about this on GitHub: https://github.com/pytorch/pytorch/issues/48365).

The above works well when the error occurs in the forward pass, but if the error occurs somewhere in the backward pass and you use DistributedDataParallel, you get an exception such as:

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).

Likely some parameters already have a computed gradient when the model runs out of memory partway through the backward pass. Since this is very tricky to debug, I would appreciate any pointers on where to look.
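
To make the failure concrete, here is a simplified sketch of the situation (names and structure are illustrative only, not my actual _do_iteration):

try:
    loss = ddp_model(data)["loss"]  # forward pass under DDP
    loss.backward()                 # the OOM can also happen here, mid-backward
except RuntimeError as e:
    if "out of memory" in str(e):
        # Same recovery as above: drop gradients, free memory, skip the batch.
        optimizer.zero_grad()
        gc.collect()
        torch.cuda.empty_cache()
        # With DistributedDataParallel, the *next* forward call then raises the
        # "Expected to have finished reduction ..." error quoted above.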

Thanks for filing the GitHub issue! Just to confirm the scenario you are seeing, sometimes you see an OOM in the fwd pass (which is handled by your try-catch block), whereas an OOM in the bwd pass results in the RuntimeError you posted.

Are you able to complete an iteration of training without seeing an OOM? If not, the runtime error may actually be due to some value returned by the fwd function that’s not used in the loss computation.
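
If some outputs or parameters legitimately do not contribute to the loss, a quick experiment is to enable unused-parameter detection, as the error message itself suggests (a minimal sketch; model and local_rank are placeholders):

from torch.nn.parallel import DistributedDataParallel as DDP

# find_unused_parameters=True lets DDP tolerate parameters/outputs that do not
# participate in the loss for a given iteration, at some extra overhead.
ddp_model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)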

Regarding future debugging: first, here is another thread about why the torch.cuda.empty_cache() function is not recommended. I can think of the following ways to get around the OOM issue more robustly:

  • Use model parallelism. For CPU-based models you can check out the RPC framework (we are working on robust GPU support for the RPC framework). Otherwise you can split the model manually, call the forward functions on each shard, and move activations around using .to() (see the sketch after this list). Here is a recent question about this.
  • Try reducing the batch size
  • Use an optimizer that needs to store less local state (e.g., SGD instead of Adam)
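
A minimal sketch of that manual split, assuming a toy two-stage model and two visible GPUs (layer sizes and names are illustrative, not from the actual model):

import torch.nn as nn

class ShardedModel(nn.Module):
    """Toy example: the first shard lives on cuda:0, the second on cuda:1."""

    def __init__(self):
        super().__init__()
        self.stage1 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()).to("cuda:0")
        self.stage2 = nn.Linear(1024, 10).to("cuda:1")

    def forward(self, x):
        x = self.stage1(x.to("cuda:0"))
        # Move the intermediate activations to the second GPU by hand.
        return self.stage2(x.to("cuda:1"))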

@osalpekar Thanks for your reply.

The above code indeed works when an error occurs during the forward pass, skipping a batch and happily continuing, but when the error occurs in the backward pass I get the exception above (tested by increasing the batch size until an OOM occurred in the forward pass).

In general it works well and I do not yet need model parallelism, but I do get an OOM error about once a day. To give a bit of an idea: this is for my MRI reconstruction project, where I can use 40GB of GPU memory and the model typically takes around ±34GB (batch size 1 per GPU), yet I can still get an OOM in the backward pass.

I am not particularly sure why and how this happens, but it seems to be rather deep in the PyTorch internals; checking how it is handled in e.g. detectron2, skipping the batch like this seems to be a pragmatic solution.

Blatantly ignoring the above exception and continuing with the next batch just freezes the training, by the way, so there must be something I need to reset.

@jteuwen I see, thanks for the added context!

It indeed seems like the workflow may be OOM-prone, given that a 34GB model, the corresponding gradients and optimizer states, and an input sample must fit into 40GB. Is my understanding of these memory sizes correct? There has been some work on DDP to reduce the memory overhead, perhaps @mrshenli may be able to shed some more light on that.

As an aside, torchelastic is a great way of recovering from errors and restarting training from a checkpoint. I’m curious whether it will result in the GPU tensors being freed (which could replace the failure recovery script shared above) cc @Kiuk_Chung

There has been some work on DDP to reduce the memory overhead, perhaps @mrshenli may be able to shed some more light on that.

Thanks @osalpekar. The feature is the gradient_as_bucket_view arg in the DDP constructor. It modifies the param.grad fields and lets them point to views into the DDP communication buckets, so it can save memory equal to about one copy of the model (the gradients). This is still a prototype feature and subject to change.

https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=distributeddataparallel#torch.nn.parallel.DistributedDataParallel
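
A minimal sketch of enabling it (model and local_rank are placeholders; the process group is assumed to be initialized already):

from torch.nn.parallel import DistributedDataParallel as DDP

ddp_model = DDP(
    model,
    device_ids=[local_rank],
    gradient_as_bucket_view=True,  # param.grad becomes a view into the comm buckets
)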

+1 to @osalpekar’s comment that torchelastic is the recommended solution for DDP OOMs. The RuntimeError you saw is caused by desync, and the desync is caused by the OOM in one process: DDP expects all processes to launch the same number of AllReduce comm ops in every iteration, and if one process hits an OOM and skips/redoes some comm ops, that assumption is broken. TorchElastic handles this by checkpointing model state and letting the entire DDP gang revert to the previous checkpoint when it detects a failure in any process, which seems to be a very nice fit for this infrequent and unpredictable OOM.
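
A rough sketch of the checkpoint/resume pattern this relies on (the path, keys, and helper names are illustrative): one rank saves periodically, and after the elastic agent restarts the workers, every rank reloads the latest checkpoint instead of trying to recover in place.

import os

import torch
import torch.distributed as dist

CHECKPOINT_PATH = "/tmp/checkpoint.pt"  # illustrative location

def save_checkpoint(model, optimizer, iteration):
    # One writer is enough; all ranks read the file back after a restart.
    if dist.get_rank() == 0:
        torch.save(
            {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "iteration": iteration,
            },
            CHECKPOINT_PATH,
        )

def load_checkpoint(model, optimizer):
    # Returns the iteration to resume from (0 if no checkpoint exists yet).
    if not os.path.exists(CHECKPOINT_PATH):
        return 0
    state = torch.load(CHECKPOINT_PATH, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["iteration"] + 1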
