Got CUDA OOM Error after 8 Hours of Training

Hi, I’m training a large causal language model (CLM) on 8 GPUs using FSDP with gradient checkpointing. To build the training data, I concatenated all the tokenized examples and then split them into blocks of exactly the same length. Training ran smoothly for 8 hours and was then terminated by a CUDA OOM error.
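The packing step described above can be sketched roughly as follows. This is only an illustration, not the actual preprocessing code: `pack_into_blocks` and `block_size` are hypothetical names, and the sketch assumes the trailing remainder shorter than one block is simply dropped.

```python
from itertools import chain


def pack_into_blocks(tokenized_examples, block_size):
    """Concatenate token-id lists and cut them into equal-length blocks.

    Assumption: the leftover tail shorter than block_size is dropped,
    so every returned block has exactly block_size tokens.
    """
    concatenated = list(chain.from_iterable(tokenized_examples))
    usable = (len(concatenated) // block_size) * block_size
    return [concatenated[i:i + block_size] for i in range(0, usable, block_size)]


# Toy input: three "tokenized examples" of different lengths.
blocks = pack_into_blocks([[1, 2, 3], [4, 5], [6, 7, 8, 9]], block_size=4)
print(blocks)
```

With this kind of packing every batch has the same shape, which is why the per-step memory footprint would be expected to stay constant.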

I’m trying to resume training from the latest checkpoint, but now the CUDA OOM error appears much sooner and at seemingly random points. Sometimes it happens at the second step after resuming; other times, resuming from the same checkpoint, it happens at step 25. Either way, after resuming from the latest checkpoint, training has never run for more than 30 minutes.

My questions are:

  1. Given that all my training examples have exactly the same length (and I don’t run evaluation during training), shouldn’t memory consumption be identical from step to step? Why can training run smoothly for 8 hours and then suddenly hit a CUDA OOM error?
  2. After resuming from the latest checkpoint, why do I get the CUDA OOM error much sooner? And why does it happen at random steps, even though I’m using the same data, the same hyperparameters, and the same seed?

I tried explicitly calling Python’s garbage collector and clearing the CUDA cache at each training step:

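import gc

class CLMTrainer(transformers.Trainer):
    def training_step(self, model: nn.Module, inputs) -> torch.Tensor:
        res = super().training_step(model, inputs)
        gc.collect()              # run Python's garbage collector
        torch.cuda.empty_cache()  # release cached, unused blocks back to the driver
        return res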

But it didn’t help.
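To check whether step-to-step memory really is flat, one option is to log what PyTorch's caching allocator reports at each step. The following is a minimal sketch, not something I have run on this job: `memory_gap_mib` and `log_cuda_memory` are hypothetical helper names, and it only uses `torch.cuda.memory_allocated`, `torch.cuda.memory_reserved`, and `torch.cuda.max_memory_allocated` on the current device.

```python
def memory_gap_mib(allocated_bytes: int, reserved_bytes: int) -> float:
    """Memory held by the caching allocator but not backing live tensors, in MiB."""
    return (reserved_bytes - allocated_bytes) / (1024 ** 2)


def log_cuda_memory(step: int) -> None:
    """Print allocated, reserved, and peak memory for the current CUDA device."""
    import torch  # imported lazily so memory_gap_mib can be used without torch

    if not torch.cuda.is_available():
        return
    allocated = torch.cuda.memory_allocated()
    reserved = torch.cuda.memory_reserved()
    peak = torch.cuda.max_memory_allocated()
    print(
        f"step {step}: allocated={allocated / 2**20:.0f} MiB, "
        f"reserved={reserved / 2**20:.0f} MiB, "
        f"peak={peak / 2**20:.0f} MiB, "
        f"gap={memory_gap_mib(allocated, reserved):.0f} MiB"
    )
```

`log_cuda_memory(step)` could be called right after `super().training_step(...)` in the trainer subclass. As a rough heuristic only, a gap that keeps growing while allocated memory stays flat would point toward allocator fragmentation rather than a leak.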

My training env:

  • torch: 2.0.0+cu118
  • cuda: 11.8
  • transformers: 4.48.1

Below are the training curve and the GPU memory consumption plot for the first 8 hours of training (before the CUDA OOM error). Everything looks fine, and I don’t see any overall increase in GPU memory usage.

Any comments or suggestions would be appreciated. Thanks in advance!