I am training a NAS (neural architecture search) network that sometimes runs out of memory on the GPUs, because each run samples a different predicted model configuration.
import torch

def check_oom(func):
    # Catch a RuntimeError (CUDA OOM) raised inside the wrapped function,
    # release cached GPU memory, and return None so the caller can skip the batch.
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except RuntimeError:
            torch.cuda.empty_cache()
            return None
    return wrapper

# model's forward function
@check_oom
def forward(model, input):
    return model(input)

# main loop
def main():
    dataset = DataLoader(...)
    net = Model(...)
    for input in dataset:
        output = forward(net, input)
        # ... training code ...
        torch.cuda.synchronize()
The above is the pseudo-code I use for training. In practice, however, an OOM event hangs the entire training run, with GPU utilization stuck at 100%.
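For reference, since the decorated forward returns None after an OOM, the rest of the step has to skip that batch. A simplified sketch of what the training code does (criterion, optimizer, and target are placeholders here, not my exact code):

for input, target in dataset:
    output = forward(net, input)
    if output is None:
        # forward hit an OOM and returned None, so skip this batch
        continue
    loss = criterion(output, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    torch.cuda.synchronize()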
What can I do?