CUDA out of memory - on the 8th epoch?

Hey,

My training is crashing with a ‘CUDA out of memory’ error, but only at the 8th epoch. As I understand it, unless there is a memory leak or I am writing data to the GPU that is not freed every epoch, CUDA memory usage should not increase as training progresses. And if the model were too large to fit on the GPU, it would not get through the first epoch of training.

Any advice on how to debug this issue would be greatly appreciated, as I am not sure where to even begin. Could faulty garbage collection be responsible? If so, are there any settings I can pass to Python’s/PyTorch’s GC to prevent this crash at the cost of performance? Are there any known bugs that could cause this, and if so, are there any known workarounds? Does anyone know of any tools that can help me pinpoint the cause, such as a CUDA memory monitor that can tell me how much memory usage increases every epoch?

Hi

Could faulty garbage collection be responsible?

Python is refcounted, so no issue here :)

The usual reasons this happens are:

  • You accumulate your loss (for later printing) in a differentiable manner, like all_loss += loss. This means that all_loss keeps the history of all the previous iterations. You can fix it by doing all_loss += loss.item() to get a Python number that does not track gradients.
  • You reuse some Tensor that requires_grad from one iteration to the next, effectively growing the graph of history needed to compute gradients at each iteration. You need to check every element that is reused from one iteration to the next.
  • You save some state in a list in a differentiable manner. Add .detach() to break the history graph and store only the value of the Tensor (and not all its history).
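To make the first and third cases concrete, here is a minimal sketch that runs on the CPU. The model, the data, and the names leaky_total, safe_total and history are made up for illustration; only the accumulation pattern matters.

```python
import torch

# Toy stand-ins for a real model and dataset (assumptions for this sketch).
model = torch.nn.Linear(4, 1)
batches = [torch.randn(8, 4) for _ in range(3)]

leaky_total = 0.0  # will silently turn into a Tensor that tracks history
safe_total = 0.0   # stays a plain Python float
history = []       # stores values only, no autograd history

for x in batches:
    loss = model(x).pow(2).mean()

    # Problem: the running sum references the graph of every iteration,
    # so none of those graphs can be freed.
    leaky_total = leaky_total + loss

    # Fix 1: .item() returns a Python number with no graph attached.
    safe_total += loss.item()

    # Fix 2: .detach() keeps the value as a Tensor but drops its history.
    history.append(loss.detach())

print(type(leaky_total), leaky_total.grad_fn)  # a Tensor with a grad_fn
print(type(safe_total))                        # a plain float
print(history[0].grad_fn)                      # None
```

Printing grad_fn on whatever you keep between iterations is a quick check: anything other than None means the value still holds on to a graph.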

Let me know if it helps!

I understand that this guarantees no leaks, but does that guarantee that GC will always be on time?

All of my loss accumulation is already using loss.item().

I was actually guilty of this. I am training a ‘hard’ attention recurrent model that directs a ‘window’ around an image to produce ‘glimpses’, and I was storing the past glimpse locations in a memory tensor. I instantiated this memory with torch.zeros(), which sets requires_grad to False by default, but I didn’t realize that setting an element of this memory with a non-detached tensor would attach the entire memory tensor to the autograd graph. I now call .detach() on the input tensor before inserting it into memory. I also posted about this on the repo I based my code on for anyone curious.
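The behaviour described above can be reproduced in a few lines on the CPU. This is only a sketch: w, loc, memory and safe_memory are hypothetical stand-ins, not the names used in the actual model.

```python
import torch

# Stand-in for model parameters (assumption for this sketch).
w = torch.randn(2, requires_grad=True)

# A memory buffer created with torch.zeros() does not require grad.
memory = torch.zeros(3, 2)
print(memory.requires_grad)  # False

# Hypothetical glimpse location computed from the parameters.
loc = w * 2.0

# Writing a non-detached tensor into one slot attaches the whole buffer
# to the autograd graph.
memory[0] = loc
print(memory.requires_grad)  # True

# Detaching before the write keeps the buffer free of history.
safe_memory = torch.zeros(3, 2)
safe_memory[0] = loc.detach()
print(safe_memory.requires_grad)  # False
```

If such a buffer lives across iterations, each write extends the graph it references, which would match memory growing slowly until a later epoch fails.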

I’m restarting training now, and I’ll report back on whether this helped once training is over. Each epoch takes 30+ minutes, so it will be a while :). Thanks for the help, it saved me a day or two of pulling my hair out.

The Python standard does not guarantee that. But CPython (the only Python implementation we support) uses reference counting and deletes an object as soon as its refcount reaches zero (no garbage collector is involved in that). So yes, as long as you use CPython (and you don’t create reference cycles, but that is fairly rare and should not happen).
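A small standard-library sketch of that behaviour, assuming CPython. The Buffer class is a hypothetical stand-in for any object that owns a large allocation, and weak references are used only to observe when an object is actually freed.

```python
import gc
import weakref


class Buffer:
    """Hypothetical stand-in for an object holding a large allocation."""


gc.disable()  # keep the cycle collector from running on its own during the demo

# No cycle: the object is freed the moment its last reference goes away.
a = Buffer()
ref_a = weakref.ref(a)
del a
freed_immediately = ref_a() is None

# Reference cycle: the refcount never reaches zero by itself.
b = Buffer()
b.self_ref = b
ref_b = weakref.ref(b)
del b
cycle_alive = ref_b() is not None

# Only the cycle collector can reclaim the object in this case.
gc.collect()
cycle_freed = ref_b() is None

gc.enable()
print(freed_immediately, cycle_alive, cycle_freed)
```

So the timing question only arises for objects caught in a reference cycle. Everything else is released as soon as the last reference is dropped.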

Congrats on finding the issue!
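On the question of tooling from the original post, one option is to log what PyTorch's own allocator reports at the end of every epoch. The sketch below assumes a helper named cuda_memory_mib (a made-up name) and a placeholder loop. torch.cuda.memory_allocated and torch.cuda.max_memory_allocated are the PyTorch calls doing the work.

```python
import torch


def cuda_memory_mib(device=None):
    """Return (allocated, peak) tensor memory in MiB, or zeros without CUDA."""
    if not torch.cuda.is_available():
        return 0.0, 0.0
    mib = 1024 ** 2
    allocated = torch.cuda.memory_allocated(device) / mib
    peak = torch.cuda.max_memory_allocated(device) / mib
    return allocated, peak


for epoch in range(3):
    # ... one epoch of training would run here ...
    allocated, peak = cuda_memory_mib()
    print(f"epoch {epoch}: allocated={allocated:.1f} MiB, peak={peak:.1f} MiB")
```

If the allocated figure climbs steadily from one epoch to the next, something is keeping tensors (or their graphs) alive across iterations.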