LSTM memory consumption issue

Hello,
I have a network that uses 8 LSTM layers of size 512 each, with a batch size of 1024 and a sequence length of 512. This causes Python to use 100% of the memory and my PC freezes up, even though I have upgraded to 32 GB of RAM.
I am using a large batch size because I have a very large training dataset, so it seems to make sense to get through the data quicker. When I reduce the batch size and sequence length, it doesn’t freeze up, but training is unreasonably slow.

  1. Is there a way to have PyTorch not consume so much memory?
  2. Would it help if I increased the RAM again? The last upgrade seemed to make no difference.
  3. Are the batch size and/or the architecture unreasonable? From my limited experience with RNNs, the architecture doesn’t seem that big, but maybe I’m mistaken.

Thanks,
Nathan

Based on your description, I guess that either your RAM is already full and the system is using swap, or the process is indeed stuck and the OS should eventually kill it.
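To put rough numbers on this, here is a back-of-the-envelope estimate of the activations kept alive for the backward pass. It assumes float32 and that each LSTM layer stores about six hidden-size vectors per time step and sample (the four gate activations, the cell state and the hidden output). The exact count depends on the backend, so treat this as an order-of-magnitude sketch, not a measurement:

```python
# Rough estimate of the activation memory an LSTM stack keeps for backward.
# Assumption (not a measurement): per layer, per time step and per sample,
# ~6 * hidden_size float32 values are stored (4 gates + cell state + hidden
# output). Real backends may store more or less than this.

BYTES_PER_FLOAT32 = 4
VALUES_PER_UNIT = 6  # 4 gate activations + cell state + hidden output (assumed)


def lstm_activation_bytes(num_layers: int, hidden_size: int,
                          batch_size: int, seq_len: int) -> int:
    """Approximate bytes of saved activations for one training step."""
    values = num_layers * batch_size * seq_len * VALUES_PER_UNIT * hidden_size
    return values * BYTES_PER_FLOAT32


def gib(num_bytes: int) -> float:
    """Convert bytes to GiB."""
    return num_bytes / 2**30


if __name__ == "__main__":
    # Configuration from the question: 8 layers x 512 units, batch 1024, seq 512.
    full = lstm_activation_bytes(8, 512, 1024, 512)
    print(f"batch 1024: ~{gib(full):.1f} GiB")  # ~48.0 GiB, more than 32 GB of RAM

    # The estimate grows linearly with the batch size (and the sequence length),
    # so measurements at small batch sizes can be extrapolated.
    for bs in (1, 2, 4, 128):
        print(f"batch {bs}: ~{gib(lstm_activation_bytes(8, 512, bs, 512)):.3f} GiB")
```

Even under these optimistic assumptions, the activations alone would need roughly 48 GiB, before counting parameters, gradients and optimizer state.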

  1. If your current training really needs this amount of memory, there wouldn’t be a workaround other than, e.g., lowering the batch size or using torch.utils.checkpoint to trade compute for memory. Also check whether you are storing unnecessary data (e.g. tensors that are still attached to the computation graph) and remove it.
  2. Maybe, but you should first check how the memory usage grows with the batch size, e.g. by checking the RAM usage for a batch size of 1, 2, 3, … and estimating how much memory your desired batch size would need.
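A minimal sketch of the torch.utils.checkpoint approach, assuming the stacked LSTM can be split into single-layer nn.LSTM modules. The class name CheckpointedLSTM and its arguments are hypothetical, not part of PyTorch. Roughly, only each layer's input is kept during the forward pass, and a layer's internal activations are recomputed when its backward runs:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedLSTM(nn.Module):
    """Hypothetical stack of single-layer LSTMs whose activations are
    recomputed during backward instead of stored (compute for memory)."""

    def __init__(self, input_size: int, hidden_size: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.LSTM(input_size if i == 0 else hidden_size, hidden_size,
                    batch_first=True)
            for i in range(num_layers)
        )

    @staticmethod
    def _run_layer(layer: nn.LSTM, x: torch.Tensor) -> torch.Tensor:
        out, _ = layer(x)  # drop (h_n, c_n), keep the per-step outputs
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.training and torch.is_grad_enabled():
                # Non-reentrant checkpointing: parameters still receive
                # gradients even if the input does not require grad.
                x = checkpoint(self._run_layer, layer, x, use_reentrant=False)
            else:
                x = self._run_layer(layer, x)
        return x


if __name__ == "__main__":
    model = CheckpointedLSTM(input_size=16, hidden_size=32, num_layers=3)
    x = torch.randn(4, 10, 16)  # (batch, seq_len, features) with batch_first=True
    out = model(x)
    out.sum().backward()
    print(out.shape)  # torch.Size([4, 10, 32])
```

The backward pass becomes slower because every layer runs its forward twice, so this only helps if memory, not compute, is the bottleneck.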