Is it safe to recover from CUDA OOM?

I’m trying to improve the robustness of our trainer code. If I choose batch parameters that are not conservative enough, I may run out of memory. Is it safe to catch the OOM, reduce the batch size, and try again?



Bump. Does anyone know the answer?

What does a CUDA OOM error look like, out of curiosity?

THCudaCheck FAIL file=/u/jlquinn/jlq01/pytorch/pytorch/torch/lib/THC/generic/ line=58 error=2 : out of memory
Traceback (most recent call last):
  File "/dccstor/jlquinn01/mnlp-nn-rl-kit/src/mnlp/nn/seq2seq/tools/", line 320, in
  File "/dccstor/jlquinn-mt/nmt-env/ppc64le/lib/python2.7/site-packages/torch/autograd/", line 167, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
  File "/dccstor/jlquinn-mt/nmt-env/ppc64le/lib/python2.7/site-packages/torch/autograd/", line 99, in backward
    variables, grad_variables, retain_graph)
  File "/dccstor/jlquinn-mt/nmt-env/ppc64le/lib/python2.7/site-packages/torch/autograd/", line 91, in apply
    return self._forward_cls.backward(self, *args)
  File "/dccstor/jlquinn-mt/nmt-env/ppc64le/lib/python2.7/site-packages/torch/autograd/functions/", line 95, in backward
    grad_input =
RuntimeError: cuda runtime error (2) : out of memory at /u/jlquinn/jlq01/pytorch/pytorch/torch/lib/THC/generic/

It is safe to recover from this. The fairseq-py guys have started doing this (in case a really long sequence comes in…).

Catch the Python exception, then call torch.cuda.empty_cache() and continue business as usual.
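
For example, a minimal sketch of that pattern (the model, optimizer, criterion, inputs and targets names here are placeholders, not part of the original post; how to tell an OOM apart from other runtime errors comes up just below):

import torch

def train_step_with_oom_guard(model, optimizer, criterion, inputs, targets):
    try:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        return loss.item()
    except RuntimeError:
        # Drop the gradients that were built up, return the cached blocks
        # to the allocator, and signal the caller to retry with a smaller batch.
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        return None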


Thanks! It’s good to have the safety mechanism.

How would you catch this exception?

Like this:

except RuntimeError:

Or what is the exact exception for a CUDA OOM error?

I’d parse the error string to determine whether the RuntimeError is actually a CUDA out-of-memory error.
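
In code, that check could look something like this (train_one_batch and batch are placeholders, and the message test assumes the OOM surfaces as a RuntimeError containing "out of memory", as in the traceback above):

import torch

def is_cuda_oom(exc):
    # CUDA OOM is raised as a plain RuntimeError, so the message is the
    # only way to tell it apart from other runtime errors.
    return isinstance(exc, RuntimeError) and 'out of memory' in str(exc)

try:
    loss = train_one_batch(batch)
except RuntimeError as e:
    if not is_cuda_oom(e):
        raise  # anything that is not an OOM should still fail loudly
    torch.cuda.empty_cache()  # release cached blocks, then retry with a smaller batch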

My program has variable-sized inputs, so an OOM occurs sometimes. I used the technique above, following @marcel1991’s pseudocode, to skip to the next iteration when an OOM occurs. However, torch.cuda.empty_cache() does not seem to be enough to clean up the memory, and I cannot actually continue training: every iteration after that also hits an OOM. How can I clean up the allocated GPU memory when an OOM occurs?

Assume your training code (e.g. a Trainer) can be viewed as a class that contains the network, the optimizer, and so on.
My working solution looks like this.

import gc
import torch

retry_time = 0
while retry_time < 3:
    try:
        trainer = Trainer()  # Init the network and optimizer inside
        result = trainer.train()
        break  # break still runs the finally block before leaving the while
    except RuntimeError as e:
        retry_time += 1
        print('Runtime Error {}\nRun Again......{}/{}'.format(e, retry_time, 3))
        if retry_time == 3:
            print('Give up!')
            break  # break still runs the finally block before leaving the while
        # Handle the CUDA OOM safely: drop everything the failed trainer holds
        del trainer
    finally:
        # Collect the now-unreferenced objects and return cached blocks to the driver
        gc.collect()
        torch.cuda.empty_cache()

You can monitor the memory usage after the last line of the finally block, before the next trial, to see whether the memory is released as you expect.
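
For that monitoring, something like this works on recent PyTorch versions (the MiB formatting is just a convenience); call it right after the finally block, before the next trial:

import torch

def report_gpu_memory(tag=''):
    # memory_allocated: bytes currently held by live tensors
    # memory_reserved:  bytes held by the caching allocator, including cached blocks
    alloc = torch.cuda.memory_allocated() / 1024 ** 2
    reserved = torch.cuda.memory_reserved() / 1024 ** 2
    print('{}: allocated={:.1f} MiB, reserved={:.1f} MiB'.format(tag, alloc, reserved))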
