I'm having a hard time understanding memory allocation during the backward pass of an LLM. As I increase the depth of the model, I'm seeing a much larger increase in backward-pass memory than what I estimated using *The Mathematics of Training LLMs — with Quentin Anthony of EleutherAI*. I'm running on a single GPU for now, but I have FSDP (FULL_SHARD) enabled, which I use to scale the model, and I'm running it with PyTorch Lightning (roughly configured as in the sketch below).
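For context, the FSDP setup looks roughly like this (a minimal sketch with the Lightning 2.x API; `DiTBlock` is a stand-in for my actual block class and the real wrap policy may differ):

```python
# Minimal sketch of enabling FSDP full-shard via PyTorch Lightning.
# `DiTBlock` is a placeholder for my actual transformer block class.
import torch.nn as nn
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy

class DiTBlock(nn.Module):  # stand-in for the real DiT block
    def __init__(self, dim: int = 1024):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.mlp(x)

strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",   # shard params, grads, and optimizer state
    auto_wrap_policy={DiTBlock},      # wrap each block as its own FSDP unit
)

trainer = L.Trainer(accelerator="gpu", devices=1, strategy=strategy)
```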
| Depth | Mem before DiT blocks | Mem after DiT blocks | Mem after backward |
| --- | --- | --- | --- |
| 2 | 6081 | 8445 | 27527 |
| 10 | 13201 | 16525 | 64533 |
| 20 | 22101 | 26865 | OOM |
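For reference, this is roughly how the numbers in the table are collected (a self-contained sketch: the stack of `Linear` layers, the input shape, and the toy loss are placeholders; in the real run the same three readings happen at the corresponding points inside my LightningModule's `training_step`):

```python
# Sketch of the three measurement points in the table above.
# The Linear stack and random input stand in for my actual DiT blocks and batch.
import torch
import torch.nn as nn

device = "cuda"
dit_blocks = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(10)]).to(device)
x = torch.randn(8, 256, 1024, device=device, requires_grad=True)

def mib() -> float:
    return torch.cuda.memory_allocated(device) / 2**20

print("mem before DiT blocks:", mib())

out = dit_blocks(x)                  # forward: activations are saved for backward
print("mem after DiT blocks:", mib())

out.pow(2).mean().backward()         # backward: gradient buffers get allocated here
print("mem after backward:", mib())
```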
I've tried to check whether it's just cached memory, but calling torch.cuda.empty_cache() doesn't change the numbers much; the model eventually ends up allocating the same amount of memory again, presumably because it needs it to run efficiently.
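For that cache check I compared allocated vs. reserved memory around the empty_cache() call, something like:

```python
# Cache check: empty_cache() only releases reserved-but-unallocated caching-allocator
# blocks back to the driver, so truly allocated memory stays the same.
import torch

def report(tag: str) -> None:
    alloc = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"{tag}: allocated={alloc:.0f} MiB, reserved={reserved:.0f} MiB")

report("before empty_cache")
torch.cuda.empty_cache()
report("after empty_cache")
```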
Does anyone know why this might be the case, or what tools I can use to debug this?