Resume DDP training after out of memory error

I am training a NAS network that sometimes runs out of memory on the GPUs, because each run predicts a different model configuration.

import torch

def check_oom(func):
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except RuntimeError:
            # on OOM, free cached blocks and skip this step
            torch.cuda.empty_cache()
            return None
    return wrapper

# model's forward function
@check_oom
def forward(model, input):
    return model(input)

# main loop
def main():
    dataset = DataLoader(...)
    net = Model(...)
    for input in dataset:
        output = forward(net, input)
        # ... training code ...
    torch.cuda.synchronize()

Above is the pseudo-code I use for my training. In practice, however, an OOM event hangs the entire training run, with the GPUs stuck at 100% utilization.
What can I do?

Hey @Scott_Hoang, yes, this is the expected behavior: an OOM in one of the processes desynchronizes AllReduce communication across the entire DDP gang. https://pytorch.org/elastic is built to solve this problem. It tears down all DDP instances across all processes, reconstructs a new gang, and then recovers from the previous checkpoint.
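For reference, here is a minimal sketch of how an elastic-friendly loop could look. It is not your exact setup: Model, Dataloader, num_epochs, and the checkpoint path are placeholders carried over from your pseudo-code, and I'm assuming a recent PyTorch where the torchelastic launcher is available as torchrun. The key change is to stop swallowing the OOM; let the exception kill the worker so the elastic agent can restart the whole gang from the last checkpoint.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

CKPT = "checkpoint.pt"  # hypothetical checkpoint path

def main():
    # torchrun sets RANK / WORLD_SIZE / LOCAL_RANK for each worker
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    net = DDP(Model(...).cuda(local_rank), device_ids=[local_rank])  # Model as in your pseudo-code
    opt = torch.optim.SGD(net.parameters(), lr=0.1)

    # resume from the last checkpoint after an elastic restart
    start_epoch = 0
    if os.path.exists(CKPT):
        ckpt = torch.load(CKPT, map_location=f"cuda:{local_rank}")
        net.module.load_state_dict(ckpt["model"])
        opt.load_state_dict(ckpt["optim"])
        start_epoch = ckpt["epoch"] + 1

    for epoch in range(start_epoch, num_epochs):
        for input in Dataloader(...):  # Dataloader as in your pseudo-code
            output = net(input)
            # ... loss / backward / opt.step(); let an OOM raise and exit the worker ...
        if dist.get_rank() == 0:
            # one rank writes the checkpoint; the others wait at the barrier
            torch.save({"model": net.module.state_dict(),
                        "optim": opt.state_dict(),
                        "epoch": epoch}, CKPT)
        dist.barrier()

Launched with something like torchrun --standalone --nproc_per_node=8 --max_restarts=3 train.py, a worker that OOMs simply exits; the agent then tears down the group, restarts all workers, and they resume from the saved checkpoint instead of hanging in AllReduce.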

cc @Kiuk_Chung


This is exactly what I was looking for. Thanks!