EDIT: This was resolved and is not an issue. The “bug” was simply an unexpectedly large memory usage by the RNNs. Sorry about that. I'm keeping the post for the record.
This post is about an issue I reported on the Slack channel. I thought it would be better to post it here in case someone else has the same problem or a solution.
My model is an encoder-decoder with attention, using LSTMs. The decoder input has a maximum length of 140. I’ve noticed a couple of things regarding memory usage. I’m using a Titan X with 12 GB of memory.
What I observe is that each backward call allocates extra memory and I get an OOM error, but this doesn't happen if I reduce the decoder input length below a certain value (about 50). In that case, total memory usage is stable at nearly 12 GB.
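A rough way to see why decoder length matters: backpropagation through time keeps the activations for every decoder step alive until `backward()` runs, so activation memory grows roughly linearly with both sequence length and batch size. The sketch below is only a back-of-envelope estimate under stated assumptions; the function name `decoder_activation_bytes`, the number of tensors saved per step, and the hidden and vocabulary sizes are all hypothetical, not values from my model or from any framework's internals.

```python
GiB = 1024 ** 3


def decoder_activation_bytes(seq_len, batch_size, hidden_size, vocab_size,
                             tensors_per_step=6, bytes_per_elem=4):
    """Hypothetical estimate of decoder activations kept for backward.

    tensors_per_step is an ASSUMPTION (roughly four gates plus cell and
    hidden state per step, each batch_size x hidden_size); real LSTM
    kernels may save more or fewer buffers. bytes_per_elem=4 assumes float32.
    """
    # LSTM buffers saved at every time step for backpropagation through time
    lstm = seq_len * tensors_per_step * batch_size * hidden_size * bytes_per_elem
    # output logits over the vocabulary: one row per (time step, example)
    logits = seq_len * batch_size * vocab_size * bytes_per_elem
    return lstm + logits


if __name__ == "__main__":
    # hypothetical sizes chosen only for illustration
    for length in (50, 140):
        b = decoder_activation_bytes(length, batch_size=64,
                                     hidden_size=512, vocab_size=30000)
        print(f"decoder length {length:3d}: ~{b / GiB:.2f} GiB of activations")
```

With these made-up sizes the vocabulary-sized logits dominate, and going from length 50 to 140 multiplies the estimate by 2.8. The estimate ignores weights, gradients, optimizer state, and memory held by the allocator's cache, so it only indicates how activation memory scales, not the absolute total on the GPU.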
However, if I instead reduce batch_size to 1, I eventually get an OOM error, even though the initial memory allocation is only around 9 GB. Extra memory is allocated for a few iterations before usage stabilizes.
Is there anything under the hood that increases memory usage during the first few iterations?