I am training a NAS network that sometimes runs out of memory on the GPUs, because each run predicts a different model configuration.
import torch

def check_oom(func):
    # Decorator: catch a CUDA OOM in the wrapped call, free cached memory, and skip the batch
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise  # re-raise anything that is not an OOM error
            torch.cuda.empty_cache()
            return None
    return wrapper
# model's forward function
@check_oom
def forward(model, input):
    return model(input)
# main loop
def main():
    dataset = DataLoader(...)
    net = Model(...)
    for input in dataset:
        output = forward(net, input)
        (....... training code)
        torch.cuda.synchronize()
The above is the pseudo-code I use for my training. In practice, however, an OOM event hangs the entire training run, with the GPUs stuck at 100% utilization.
What can I do?
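For reference, this is roughly how I imagine the loop could skip a batch when the decorated forward returns None after an OOM. It is only a sketch; optimizer, criterion and target are placeholders, not my actual training code:

for input in dataset:
    output = forward(net, input)
    if output is None:
        # forward hit an OOM: drop stale gradients and cached blocks, then skip this batch
        optimizer.zero_grad(set_to_none=True)
        torch.cuda.empty_cache()
        continue
    loss = criterion(output, target)  # placeholder for the real training step
    loss.backward()
    optimizer.step()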