LSTM with for loop OOM

Hi,

I want to unroll an LSTM step by step with a for loop (in order to do attention), but I run out of memory almost immediately. I’ve seen this issue discussed before (https://github.com/pytorch/pytorch/issues/914), and my understanding was that cuDNN was the cause. I disabled it with torch.backends.cudnn.enabled = False, but I still get OOM. I also understood from that thread that the issue was later fixed, yet I am using the latest version of PyTorch and cuDNN v7 (when enabled). I’ve also tried nn.LSTMCell instead, with no improvement. Without unrolling I can use a batch size of 64; with unrolling it only runs with a batch size of 8.
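For reference, step-by-step unrolling with nn.LSTMCell usually looks something like the minimal sketch below. The sizes, the `memory` tensor standing in for encoder outputs, and the plain dot-product attention are all hypothetical choices for illustration, not the poster's actual model.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
batch, seq_len, input_size, hidden_size, mem_len = 4, 10, 16, 32, 7

cell = nn.LSTMCell(input_size, hidden_size)
x = torch.randn(seq_len, batch, input_size)         # time-major input
memory = torch.randn(batch, mem_len, hidden_size)   # hypothetical encoder outputs to attend over

h = torch.zeros(batch, hidden_size)
c = torch.zeros(batch, hidden_size)
contexts = []
for t in range(seq_len):
    h, c = cell(x[t], (h, c))                        # one LSTM step
    # Dot-product attention of the current hidden state over `memory`.
    scores = torch.bmm(memory, h.unsqueeze(2)).squeeze(2)          # (batch, mem_len)
    weights = torch.softmax(scores, dim=1)
    context = torch.bmm(weights.unsqueeze(1), memory).squeeze(1)   # (batch, hidden_size)
    contexts.append(context)

out = torch.stack(contexts)                          # (seq_len, batch, hidden_size)
```

Keeping `contexts` in a list is fine here because every step is needed for this forward/backward pass; the memory problem comes from tensors that stay referenced after their pass is over.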

Could you please advise on the current best way to unroll an RNN with a for loop?

Thanks,
Oana


You are likely holding onto some Variables across time steps, and those keep alive the buffers that autograd saved for the backward pass.
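One common way this happens (an assumed scenario, not necessarily the poster's) is carrying the LSTM hidden state from one training iteration into the next. A minimal sketch of the fix, assuming hypothetical sizes and a hypothetical linear `head`: detach the carried-over state at the start of each iteration so the new graph does not reach back into earlier ones.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical sizes and model, for illustration only.
batch, steps, input_size, hidden_size = 4, 5, 8, 16
cell = nn.LSTMCell(input_size, hidden_size)
head = nn.Linear(hidden_size, 1)
opt = torch.optim.SGD(list(cell.parameters()) + list(head.parameters()), lr=0.01)

h = torch.zeros(batch, hidden_size)
c = torch.zeros(batch, hidden_size)
for it in range(3):
    # Cut the link to the previous iteration's graph. Without this, backward()
    # would try to traverse earlier graphs: it fails once their buffers are
    # freed, and with retain_graph=True memory keeps growing instead.
    h, c = h.detach(), c.detach()
    x = torch.randn(steps, batch, input_size)
    y = torch.randn(batch, 1)
    for t in range(steps):
        h, c = cell(x[t], (h, c))
    loss = nn.functional.mse_loss(head(h), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
```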

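The most common mistake is doing:

total_loss += loss instead of total_loss += loss.data[0] (loss.item() in PyTorch 0.4 and later)

Adding the loss Variable itself chains every iteration's graph onto total_loss, whereas loss.data[0] is a plain Python number with no history.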
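A minimal sketch of the correct accumulation pattern, using a hypothetical nn.Linear as a stand-in for the real model and loss.item(), the PyTorch 0.4+ equivalent of loss.data[0]:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 1)  # hypothetical stand-in for the actual model
opt = torch.optim.SGD(model.parameters(), lr=0.1)

total_loss = 0.0
for step in range(5):
    x, y = torch.randn(16, 8), torch.randn(16, 1)
    loss = nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    # loss.item() returns a plain Python float, so total_loss carries no
    # autograd history; `total_loss += loss` would link every step's graph.
    total_loss += loss.item()
```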