I’m working on a model where I need to apply the same shared LSTM layer to every slice of an input tensor along a particular dimension (sort of like an LSTM convolution). Unfortunately, this easily causes a CUDA out-of-memory error when the LSTM has > 50 units.

For example, I’m trying to apply a 50-unit LSTM to every slice of the second dimension. If the second dimension has size 20, the LSTM outputs 50 units for each slice, totaling 20×50. The slicing is done with a for loop, and the output tensors are stacked.
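To make this concrete, here is a minimal sketch of the loop-and-stack approach described above. The shape values (batch 8, 20 slices, sequence length 30, feature size 10) are placeholders I made up for illustration, not my actual model:

```python
import torch
import torch.nn as nn

# Hypothetical shapes for illustration only.
batch, n_slices, seq_len, feat = 8, 20, 30, 10
x = torch.randn(batch, n_slices, seq_len, feat)

# One shared LSTM applied to every slice along dim 1.
lstm = nn.LSTM(input_size=feat, hidden_size=50, batch_first=True)

outputs = []
for i in range(n_slices):
    # x[:, i] has shape (batch, seq_len, feat); h_n has shape (1, batch, 50)
    _, (h_n, _) = lstm(x[:, i])
    outputs.append(h_n.squeeze(0))       # (batch, 50)

out = torch.stack(outputs, dim=1)        # (batch, 20, 50)
```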

In my TensorFlow implementation, I was easily able to use up to 256 LSTM units. However, the PyTorch implementation runs out of GPU memory. Is there a memory-efficient way to handle this? Is there also a way to process the slices in parallel, since they don’t depend on each other’s computation?
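One idea I’ve been wondering about (not sure if it’s the right approach) is folding the slice dimension into the batch dimension so the shared LSTM processes all slices in a single call, then reshaping back. Shapes below are the same made-up placeholders as before:

```python
import torch
import torch.nn as nn

batch, n_slices, seq_len, feat, hidden = 8, 20, 30, 10, 50
x = torch.randn(batch, n_slices, seq_len, feat)
lstm = nn.LSTM(input_size=feat, hidden_size=hidden, batch_first=True)

# Merge the slice dimension into the batch dimension: (batch * n_slices, seq_len, feat).
x_flat = x.reshape(batch * n_slices, seq_len, feat)

# One forward pass over all slices at once; h_n has shape (1, batch * n_slices, hidden).
_, (h_n, _) = lstm(x_flat)

# Restore the slice dimension: (batch, n_slices, hidden).
out = h_n.squeeze(0).reshape(batch, n_slices, hidden)
```

Would this be equivalent to the loop version, and would it actually help with memory, or just trade loop overhead for one larger allocation?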