I have a small model (~2M params), and I'm using batch size = 1. The size of each batch varies, but on average I'm using about 5 GB per iteration during the first epoch.
My problem is that after the first epoch, my memory consumption keeps increasing until it hits an OOM. I suspect a memory leak, so I tried the following fixes:
- call torch.cuda.empty_cache() after each memory-heavy operation
- delete my mini-batch data and loss after each iteration
- detach all logged tensors and move them to the CPU
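
For reference, here is a minimal sketch of a training loop with those fixes applied (the model, shapes, and names are placeholders, not my actual code). The `loss.item()` call is the pattern I understand is supposed to avoid keeping the autograd graph alive across iterations:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)  # placeholder stand-in for the real 2M-param model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

log = []  # loss history kept across epochs
for epoch in range(2):
    for _ in range(3):  # a few fake batches of size 1
        x = torch.randn(1, 10)
        y = torch.randn(1, 1)
        loss = loss_fn(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Store a plain Python float, not the loss tensor itself;
        # appending `loss` directly would keep its whole graph alive.
        log.append(loss.item())
        del x, y, loss  # the "delete mini-batch data and loss" fix
    if torch.cuda.is_available():
        # Releases cached blocks back to the driver; it does not
        # free tensors that are still referenced somewhere.
        torch.cuda.empty_cache()

assert all(isinstance(v, float) for v in log)
```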
Unfortunately, none of this solved the issue. Is there something I'm missing? Can you help me track it down?
Thank you in advance!