I’m trying to optimise the memory requirements of a seq2seq decoder when every decoder input is taken from the previous step’s output (i.e. no teacher forcing).
In that case, I can’t use the pack_padded_sequence method and run the RNN over the full batch in one call; instead I have to iterate over the sequence offsets, accumulating the loss from every step. From my experiments, I found that in this mode GPU memory consumption grows almost linearly with the input sequence length, which even for small LSTM nets limits sequences to about 300-400 steps.
I’ve implemented a small demonstration tool which generates a random batch of sequences and iterates over their offsets: https://gist.github.com/Shmuma/614d3dbe0ad2805d048ff0e6129682aa
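For reference, the loop is essentially of this shape (a minimal sketch with made-up sizes and module names, not the actual gist code, and written against the current tensor API rather than the 0.1.12 Variable API): each step’s output is fed back as the next input, and the per-step losses are summed before a single backward().

```python
import torch
import torch.nn as nn

BATCH, IN_SZ, HID_SZ, SEQ_LEN = 16, 32, 512, 300

decoder = nn.LSTM(IN_SZ, HID_SZ)          # sequence-first layout: (seq, batch, features)
proj = nn.Linear(HID_SZ, IN_SZ)
criterion = nn.MSELoss()

targets = torch.randn(SEQ_LEN, BATCH, IN_SZ)
inp = torch.zeros(1, BATCH, IN_SZ)        # "start" input for the first step
hidden = None                             # in a real seq2seq this would come from the encoder

loss = 0
for ofs in range(SEQ_LEN):
    out, hidden = decoder(inp, hidden)    # unroll one step at a time
    pred = proj(out.squeeze(0))
    loss = loss + criterion(pred, targets[ofs])
    inp = pred.unsqueeze(0)               # previous output becomes the next input

loss.backward()                           # graph for all SEQ_LEN steps stays alive until here
```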
Results from running it on a GTX 1080 Ti with PyTorch 0.1.12:
- seq_len=50 -> 1.5GB
- seq_len=200 -> 5.0GB
- seq_len=300 -> 7.2GB
- seq_len=400 -> 9.5GB
I guess this high memory consumption comes from the gradient information retained for every step while the loss is being summed.
So, my question is: is it possible to optimise memory consumption in this case?
My understanding is that the gradients for every LSTM matrix should eventually be aggregated across all sequence steps, but until the final .backward() call they are retained in separate per-step buffers. Is it possible to aggregate them earlier?
Another option would be to call .backward() for every sequence step, but that doesn’t look workable either, as the decoding steps are preceded by the encoder run, and I’m not sure the encoder’s gradients would still be valid. A rough sketch of what I mean is below.
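The per-step backward I have in mind would look roughly like this (again a sketch with made-up names, not taken from the gist). Parameter gradients accumulate in .grad across the repeated backward() calls, so a single optimizer step at the end should see the same summed gradient as a single backward over the summed loss. The catch is that every later step still depends on this step’s graph through the fed-back input, so the graph has to be kept alive (retain_graph=True in current versions, retain_variables in the 0.1.x API), which seems to defeat the memory saving unless the fed-back state is detached, and detaching would truncate the gradient:

```python
import torch
import torch.nn as nn

BATCH, IN_SZ, HID_SZ, SEQ_LEN = 16, 32, 512, 300

decoder = nn.LSTM(IN_SZ, HID_SZ)
proj = nn.Linear(HID_SZ, IN_SZ)
criterion = nn.MSELoss()
params = list(decoder.parameters()) + list(proj.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)

targets = torch.randn(SEQ_LEN, BATCH, IN_SZ)
inp = torch.zeros(1, BATCH, IN_SZ)
hidden = None                              # in the real model: the encoder's final state

optimizer.zero_grad()
for ofs in range(SEQ_LEN):
    out, hidden = decoder(inp, hidden)
    pred = proj(out.squeeze(0))
    step_loss = criterion(pred, targets[ofs])
    # Gradients from each step add into param.grad, so one optimizer.step()
    # at the end applies the summed gradient.  retain_graph=True is needed
    # because later steps still reference this step's graph through the
    # fed-back input -- which means the per-step memory is not actually freed.
    step_loss.backward(retain_graph=True)
    inp = pred.unsqueeze(0)                # feed prediction back, keeping its history
optimizer.step()
```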