Resume DDP training after out-of-memory error

I am training a NAS network that sometimes runs out of memory on the GPUs, because each run predicts a different model configuration.

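import functools

from torch.utils.data import DataLoader


def check_oom(func):
    # Return None instead of raising if func fails with a RuntimeError
    # (CUDA out-of-memory errors are raised as RuntimeError).
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except RuntimeError:
            return None
    return wrapper

# model's forward function, wrapped so an OOM returns None
@check_oom
def forward(model, input):
    return model(input)

# main loop
def main():
    dataset = DataLoader(...)
    net = Model(...)
    for input in dataset:
        output = forward(net, input)
        # ....... training code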

Above is the pseudo-code I use for training. In practice, however, an OOM event hangs the entire training run, with all GPUs stuck at 100% utilization.
What can I do?
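A simplified, single-machine model of why catching the error inside one process does not help under DDP: the sketch below uses a Python `threading.Barrier` as a stand-in for the gradient AllReduce that every rank must join. The world size, the rank that "runs out of memory", and the short timeout are all hypothetical; real NCCL collectives wait far longer (potentially indefinitely), which is what keeps the GPUs busy at 100%.

```python
import threading

WORLD_SIZE = 3
OOM_RANK = 1      # hypothetical rank whose forward pass "runs out of memory"
TIMEOUT_S = 0.5   # toy timeout so the demo terminates; real collectives wait much longer

# One shared barrier stands in for the AllReduce that every rank must join.
allreduce = threading.Barrier(WORLD_SIZE, timeout=TIMEOUT_S)
results = {}


def rank_step(rank):
    if rank == OOM_RANK:
        # check_oom returned None, so this rank skips backward() and
        # does not join this step's collective.
        results[rank] = "skipped"
        return
    try:
        allreduce.wait()  # blocks until all WORLD_SIZE ranks arrive
        results[rank] = "synced"
    except threading.BrokenBarrierError:
        # Raised once the barrier times out: the missing rank never arrived.
        results[rank] = "stuck"


threads = [threading.Thread(target=rank_step, args=(r,)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results)
```

The healthy ranks end up `"stuck"` because the OOM rank returned early; in a real job they would not error out after half a second but would keep sitting inside the collective.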

Hey @Scott_Hoang, yes, this is the expected behavior. When one of the processes hits an OOM, it falls out of step with the gradient AllReduce on the other processes, which desynchronizes communication across the entire DDP gang; the remaining processes are left waiting on collectives that never match up. is built to solve this problem. It tears down all DDP instances across all processes, forms a new gang, and then resumes from the previous checkpoint.
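The recovery pattern described above (tear everything down, relaunch, resume from the last checkpoint) relies on the training script saving its state regularly and reloading it on startup. The sketch below simulates that with the standard library only: `train`, `run_with_restarts`, `SimulatedOOM`, the JSON checkpoint, and the failure schedule are all hypothetical stand-ins. In an actual DDP job the state would be model and optimizer `state_dict`s saved with `torch.save`, each relaunched worker would call `torch.distributed.init_process_group` and wrap a fresh DDP model, and the restarting would be done by the elastic launcher rather than a Python loop.

```python
import json
import os
import tempfile


class SimulatedOOM(RuntimeError):
    """Hypothetical stand-in for a CUDA out-of-memory error."""


def train(ckpt_path, num_epochs, fail_at_epoch=None):
    # Resume from the last checkpoint if one exists, otherwise start fresh.
    state = {"epoch": 0, "weights": 0.0}
    if os.path.exists(ckpt_path):
        with open(ckpt_path) as f:
            state = json.load(f)
    for epoch in range(state["epoch"], num_epochs):
        if epoch == fail_at_epoch:
            raise SimulatedOOM(f"out of memory at epoch {epoch}")
        state["weights"] += 1.0  # placeholder for one epoch of training
        state["epoch"] = epoch + 1
        with open(ckpt_path, "w") as f:  # checkpoint after every epoch
            json.dump(state, f)
    return state


def run_with_restarts(ckpt_path, num_epochs, max_restarts, failures):
    # Hypothetical supervisor: relaunch the worker after each failure.
    # failures[i] is the epoch at which attempt i fails (None = no failure).
    attempts = 0
    while True:
        fail_at = failures[attempts] if attempts < len(failures) else None
        try:
            return train(ckpt_path, num_epochs, fail_at), attempts
        except SimulatedOOM:
            attempts += 1
            if attempts > max_restarts:
                raise  # give up once the restart budget is exhausted


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        final, restarts = run_with_restarts(
            os.path.join(d, "ckpt.json"), num_epochs=5, max_restarts=3, failures=[2]
        )
        print(final, restarts)
```

Because the relaunched worker resumes at epoch 2 instead of epoch 0, the finished run has done exactly five epochs of work despite the restart.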

This is exactly what I am looking for. Thanks!