How to try/except a GPU runtime error on OOM

I'm trying to catch the exception so that the model can switch to the next GPU automatically, but it does not work.

Have a look at the FairSeq example on how to recover from OOM errors.
They just skip the batch and try to continue training. You could try to adapt this example to move your model instead.
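If it helps, here's a minimal sketch of that pattern (placeholder names, not FairSeq's actual code): catch the OOM `RuntimeError`, release this step's tensors, empty the cache, and move on to the next batch. The same `except` branch is also where you could move the model to another device instead of skipping.

```python
import torch

def train_step(model, optimizer, criterion, batch):
    """One training step that skips the batch instead of crashing on CUDA OOM.

    Sketch only: model/optimizer/criterion/batch are placeholders, not FairSeq's API.
    """
    try:
        optimizer.zero_grad()
        inputs, targets = batch
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        return loss.item()
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise  # only swallow OOM errors, re-raise everything else
        # Drop the gradients/activations held for this step, then return the
        # cached blocks to the allocator so the next batch has a chance to fit.
        optimizer.zero_grad(set_to_none=True)
        torch.cuda.empty_cache()
        return None  # caller treats None as "batch skipped"
```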


This seems like a really useful trick to have handy for single-GPU training. How should this interact with torch.distributed? My first idea was to use the Join context manager, which handles a variable number of ranks participating in all-reduce operations, but it seems like Join is a one-way door: once you exit the Join context, the placeholder ops won't allow a rank to continue training. So it seems like you need to be able to dynamically participate in placeholder ops, which may not be trivial to implement. Maybe the solution is torch.distributed.elastic, where you just let the rank crash and spin it back up to rejoin the process group?
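For reference, this is roughly what the Join-based version looks like (a sketch that assumes the process group is already initialized, the model is on this rank's device, and each rank may have a different number of batches). The limitation described above is that a rank which leaves the `with Join(...)` block can't rejoin the loop for that epoch.

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.join import Join

def train_epoch(model: torch.nn.Module, rank_inputs):
    # Sketch only: assumes init_process_group() has already been called and
    # `rank_inputs` is this rank's (possibly shorter) list of batches.
    ddp_model = DDP(model)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
    # Join shadows the collective ops for ranks that run out of inputs early,
    # so the remaining ranks' all-reduces don't hang on them.
    with Join([ddp_model]):
        for batch in rank_inputs:
            optimizer.zero_grad()
            loss = ddp_model(batch).sum()
            loss.backward()
            optimizer.step()
    # Once a rank has exhausted its inputs it only shadows the other ranks'
    # collectives; it can't re-enter the loop -- the "one-way door" above.
```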

Letting the rank crash might work, but I would also assume that restoring the training might not be trivial.
I would also expect to see the same memory usage across ranks in DDP, so do you know why only one rank runs OOM?

I’m primarily focused on large language models at the moment, and long sequence lengths cause OOM issues. I could probably be more rigorous about finding the maximum sequence length (and batch size) that will fit in GPU memory, though.
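One way to be more rigorous about it: binary-search the sequence length by running a dummy forward/backward pass and catching the OOM error. A sketch (the model interface and search bounds here are assumptions, adjust for your setup):

```python
import torch

def max_seq_len_that_fits(model, batch_size, vocab_size, lo=128, hi=8192):
    """Binary-search the longest sequence length a forward/backward pass survives.

    Sketch only: assumes `model` takes a (batch, seq_len) LongTensor of token ids
    and returns a tensor whose sum() can be backpropagated.
    """
    device = next(model.parameters()).device
    best = 0
    while lo <= hi:
        mid = (lo + hi) // 2
        try:
            tokens = torch.randint(vocab_size, (batch_size, mid), device=device)
            model(tokens).sum().backward()
            best, lo = mid, mid + 1  # fits: try something longer
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            hi = mid - 1             # too long: shrink the search window
        finally:
            model.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
    return best
```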