Understanding memory usage for the DiT model

I'm having a hard time understanding memory allocation during the backward pass of my DiT model. As I increase the depth (number of DiT blocks), the memory used by the backward pass grows far more than what I estimated using The Mathematics of Training LLMs — with Quentin Anthony of Eleuther AI. I'm running on a single GPU for now, but I have FSDP (FULL_SHARD) enabled, which is what I use to scale the model, and I'm using PyTorch Lightning to run training.
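For reference, the Trainer setup looks roughly like this (simplified sketch; `DiTLightningModule` and `train_loader` are placeholders for my actual LightningModule and DataLoader):

```python
# Simplified sketch of my setup; DiTLightningModule and train_loader are
# placeholders for my actual LightningModule and DataLoader.
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy

model = DiTLightningModule(depth=10)   # depth = number of DiT blocks
trainer = L.Trainer(
    accelerator="gpu",
    devices=1,                         # single GPU for now
    strategy=FSDPStrategy(sharding_strategy="FULL_SHARD"),
)
trainer.fit(model, train_dataloaders=train_loader)
```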

| Depth | Mem before DiT blocks | Mem after DiT blocks | Mem after backward |
|-------|-----------------------|----------------------|--------------------|
| 2     | 6081                  | 8445                 | 27527              |
| 10    | 13201                 | 16525                | 64533              |
| 20    | 22101                 | 26865                | OOM                |
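To make the three measurement points concrete, here's roughly where they fall in a training step (simplified sketch; `patchify_and_embed`, `dit_blocks`, and `diffusion_loss` are placeholders for my actual model parts, and I'm reading `torch.cuda.memory_allocated()` here):

```python
# Where the three columns are measured (simplified sketch; patchify_and_embed,
# dit_blocks and diffusion_loss are placeholders for my actual model parts).
import torch

def mem_mib() -> float:
    # Live tensor allocations on the current CUDA device, in MiB
    return torch.cuda.memory_allocated() / 2**20

def training_step(batch):
    x, t, y = batch
    h = patchify_and_embed(x, t, y)            # everything before the DiT blocks
    print(f"mem before DiT blocks: {mem_mib():.0f}")

    for block in dit_blocks:                   # the stack I scale with `depth`
        h = block(h, t, y)
    print(f"mem after DiT blocks:  {mem_mib():.0f}")

    loss = diffusion_loss(h, x)
    loss.backward()                            # Lightning normally does this for me
    print(f"mem after backward:    {mem_mib():.0f}")
    return loss
```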

I've tried to check whether it's just allocator cache, but running torch.cuda.empty_cache() doesn't really change the numbers either, and the model ends up allocating the same amount of memory again anyway, presumably because it needs it to run efficiently.
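To be concrete about that check, here's roughly what I compared (a minimal sketch); as far as I understand, `empty_cache()` can only release cached blocks that no live tensor is using:

```python
# Check allocated (live tensors) vs. reserved (allocator cache) around
# empty_cache(). empty_cache() can only return cached-but-unused blocks to the
# driver; tensors that are still alive (e.g. activations saved for backward)
# stay allocated, so the "allocated" number shouldn't move.
import torch

def report(tag: str) -> None:
    alloc = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"{tag}: allocated={alloc:.0f} MiB, reserved={reserved:.0f} MiB")

report("before empty_cache")
torch.cuda.empty_cache()
report("after empty_cache")   # allocated unchanged; reserved may shrink a bit
```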

Does anyone know why this might be the case, or what tools I can use to debug it?
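For reference, one thing I can try is PyTorch's CUDA memory snapshot recorder, roughly as below (assuming PyTorch >= 2.1; `run_one_training_step()` is a placeholder for a single forward + backward of the DiT), but pointers on what to look for would help:

```python
# Sketch: record an allocation-level memory history for one training step and
# dump a snapshot that can be inspected at https://pytorch.org/memory_viz.
# Requires PyTorch >= 2.1; run_one_training_step() is a placeholder for my code.
import torch

torch.cuda.memory._record_memory_history(max_entries=100_000)

run_one_training_step()

torch.cuda.memory._dump_snapshot("dit_backward_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)   # stop recording
```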