I am currently working on a generative algorithm for discrete data (MCTS-like), based on Transformers, where I sample / extend tokens (nodes) to build a tree.
The Transformer is causal, so to speed things up I tried to cache the hidden states of each previous node in memory so they don't have to be recomputed.
To do so, I need to move them to the CPU in order not to run out of GPU memory, and that transfer is the bottleneck: it's very slow. I ran some benchmarks; here are the average times per iteration (an iteration meaning creating a new node and running a simulation), followed by a simplified sketch of the offload code:
- reusing hidden states, storing them on the CPU: 9.4 s/it
- reusing hidden states, keeping them on the GPU (until running OOM): 1.06 s/it
- recomputing all the hidden states during each forward pass: 6.76 s/it
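
For reference, here is a simplified sketch of what the offloading currently looks like (names and shapes are illustrative, not my exact code); each node's per-layer hidden states are moved with a plain `.cpu()` call:

```python
import torch

NUM_LAYERS, SEQ_LEN, D_MODEL = 12, 512, 512  # illustrative sizes

def offload_node(hidden_states):
    # Move a node's cached per-layer hidden states to the CPU.
    # Each .cpu() call is a synchronous copy through pageable memory,
    # which is where I suspect most of the 9.4 s/it is going.
    return [h.cpu() for h in hidden_states]

def restore_node(cached, device="cuda"):
    # Bring a node's cache back onto the GPU before extending it.
    return [h.to(device) for h in cached]

states = [torch.randn(SEQ_LEN, D_MODEL, device="cuda") for _ in range(NUM_LAYERS)]
cached = offload_node(states)   # GPU -> CPU (the slow step)
states = restore_node(cached)   # CPU -> GPU when the node is revisited
```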
I ran this on an RTX 2080 Super (8 GB); the CPU is an i7 10k something with 32 GB of RAM. Sequences are 512 tokens long, with a 12-layer Transformer (about 38M params).
I was very surprised (and disappointed) by these results: CPU offloading is a factor of ~9 slower than keeping everything on the GPU, and it's actually faster to just recompute everything at each forward pass.
Do you have any clues / possible solutions that could help?
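
One direction I have considered but not benchmarked yet (so this is just a sketch of the idea, not tested code): pre-allocating pinned CPU buffers and reusing them with non-blocking copies, since a plain `.cpu()` goes through pageable memory:

```python
import torch

def make_pinned_buffers(num_layers, seq_len, d_model):
    # Page-locked (pinned) CPU memory allows faster, asynchronous
    # device-to-host copies; allocate once and reuse per node slot.
    return [torch.empty(seq_len, d_model, pin_memory=True)
            for _ in range(num_layers)]

def offload_node_async(hidden_states, pinned):
    # Queue asynchronous GPU -> CPU copies into the pinned buffers.
    for h, buf in zip(hidden_states, pinned):
        buf.copy_(h, non_blocking=True)
    # Synchronize only before the buffers are actually read on the CPU.
    torch.cuda.synchronize()
    return pinned
```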
If nothing else is possible, I am considering limiting the tree size / memory usage. I run OOM with ~140 nodes stored, but I would like to reach at least 200 (the more the better).
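
For scale, a quick back-of-the-envelope (assuming d_model ≈ 512, which is roughly consistent with 12 layers and ~38M params): one fp32 hidden-state tensor per layer costs 12 × 512 × 512 × 4 bytes ≈ 12.6 MB per node, so ~140 nodes ≈ 1.8 GB; if K and V are cached separately per layer it doubles to ≈ 25 MB per node and ≈ 3.5 GB, which would be consistent with OOM on an 8 GB card once the model and activations are accounted for.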
I have not tried FP16 yet, as I would need to make some changes to the code to make it compatible.
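
That said, I guess I could cast just the cached tensors to FP16 around the transfer without converting the model itself, which should halve both the transfer volume and the CPU RAM footprint (hypothetical helpers, untested):

```python
def offload_node_fp16(hidden_states):
    # Halve the bytes moved over PCIe by casting before the copy.
    return [h.half().cpu() for h in hidden_states]

def restore_node_fp16(cached, device="cuda"):
    # Cast back to fp32 once the tensors are on the GPU again.
    return [h.to(device).float() for h in cached]
```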
Thank you in advance!