GPU memory spikes with FSDP and torch.compile

Hi torch users,

Is it expected for an FSDP model, after being compiled, to recompile things after a certain number of training steps? Has anyone faced similar problems? The code follows the usual FSDP training procedure:

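import torch
import torch.distributed.fsdp as fsdp

model = ...
model = fsdp.FullyShardedDataParallel(model, ...)  # wraps only decoder blocks in FSDP units
model = torch.compile(model)
optimizer = ...
for X, Y in loader:  # training loop; loader is a placeholder for the data loader
    loss = model(X, Y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()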

Here is an example from GPT-2 test runs on 2 nodes with 8 GPUs each, in FSDP mode:

As you can see above, VRAM usage suddenly spikes at around 180 seconds, while the code is still inside the training loop (no evaluation or any other changes to the model).

The compiled model is, no doubt, faster than its uncompiled (eager) counterpart. But the eager model at least does not show this strange mid-run change in VRAM usage. Why is that bad? Under high load, where GPU memory usage is close to 100%, such sudden spikes can cause OOM errors.
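To pinpoint the exact step where memory jumps, one option is to record per-step peak memory and flag steps that exceed everything seen before. This is a minimal sketch; the helper names find_memory_jumps and sample_step_peaks and the 5% threshold are hypothetical choices. torch.cuda.max_memory_allocated covers live tensors, torch.cuda.max_memory_reserved also includes cached-but-free blocks held by PyTorch's caching allocator, and neither counts the CUDA context or loaded kernel binaries, which a VRAM plot from nvidia-smi does include. Logging both counters helps tell these sources apart.

```python
def find_memory_jumps(peaks, rel_threshold=0.05):
    """Return (step, previous_max, new_peak) for each step whose peak memory
    exceeds the running maximum of all earlier steps by more than rel_threshold."""
    jumps = []
    running_max = None
    for step, peak in enumerate(peaks):
        if running_max is not None and peak > running_max * (1 + rel_threshold):
            jumps.append((step, running_max, peak))
        running_max = peak if running_max is None else max(running_max, peak)
    return jumps


def sample_step_peaks(device=None):
    """Return (peak_allocated, peak_reserved) in bytes since the last reset, then reset.

    Intended to be called once per training step, e.g. right after optimizer.step().
    """
    import torch  # lazy import: find_memory_jumps itself does not need torch

    allocated = torch.cuda.max_memory_allocated(device)
    reserved = torch.cuda.max_memory_reserved(device)
    torch.cuda.reset_peak_memory_stats(device)  # start a fresh peak window for the next step
    return allocated, reserved
```

Appending each step's pair to two lists and running find_memory_jumps on each afterwards gives the step index of any jump. Comparing against the running maximum rather than the previous step keeps ordinary step-to-step fluctuation from being reported. If the reserved peak jumps while the allocated peak stays flat, the extra memory is likely held by the caching allocator rather than by live tensors.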

Thank you for any insights!

Do you know whether the ~180-second mark shown here would still fall within a warmup step in your workload? If so, the behavior could be expected: at a minimum, a full forward pass is needed before peak memory usage from tensors alone is reached, and torch.compile may need one or more full forward+backward passes to trace the graph and JIT-compile kernels. Those JIT-compiled kernels would then occupy additional device memory once they are generated.

Thanks for the quick reply.

Not really. With different model sizes the behavior still shows up, just at a different point in training, though usually close to the beginning, e.g. at training step 42 🙂. Doing a dummy forward pass on random input before entering the loop does not make the issue go away.

Is there a way to enable some logging to see what is happening? Or does anyone know a magical place to put a print() so I can see and say “yes, it is the JIT compiler doing things here” or “no, it is something else”? A related question: if I replace the activation function implementation (a plain Python function) with its autograd version, the spike is smaller. Does that give a better idea of what it could be?
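One place to put such a print(): torch.compile accepts a custom backend, i.e. any callable that receives the captured torch.fx.GraphModule plus example inputs and returns a callable. The sketch below (CompileLogger is a hypothetical name) tags every graph Dynamo hands to the backend with the current training step. Assumption to keep in mind: it returns gm.forward, so graphs run eagerly without Inductor kernels. It shows when Dynamo compiles or recompiles, not Inductor's memory footprint.

```python
class CompileLogger:
    """Hypothetical torch.compile backend that logs each compiled FX graph
    together with the training step at which it was compiled."""

    def __init__(self):
        self.step = None  # set by the training loop at the start of each step
        self.events = []  # one (step, num_fx_nodes) entry per compiled graph

    def __call__(self, gm, example_inputs):
        num_nodes = len(list(gm.graph.nodes))
        self.events.append((self.step, num_nodes))
        print(f"[compile #{len(self.events)}] step={self.step} "
              f"graph nodes={num_nodes} example inputs={len(example_inputs)}")
        return gm.forward  # run the captured graph as-is (no Inductor)

    def compiles_after(self, step):
        """Count graphs compiled after `step`; non-zero means something recompiled."""
        return sum(1 for s, _ in self.events if s is not None and s > step)
```

Usage with the loop from the first post: create logger = CompileLogger(), compile with torch.compile(model, backend=logger), and set logger.step = step at the top of each iteration. Compile lines printed around step 42 would point at recompilation; if none appear, recompilation is unlikely to be the cause.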

Are your input sequences always the same shape? If the shapes are dynamic, recompilations are expected.
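One way to keep shapes static is to pad every batch to a bucketed length. This is a minimal sketch assuming right-padded lists of token ids; bucket_length and pad_batch are hypothetical helpers, and multiple=64 with max_len=1024 (GPT-2's context length) are illustrative defaults. Each distinct bucket is a new input shape, so Dynamo should see at most max_len / multiple sequence lengths; setting multiple equal to max_len pads everything to a single shape.

```python
def bucket_length(seq_len, multiple=64, max_len=1024):
    """Round seq_len up to the next multiple of `multiple`, capped at max_len."""
    if seq_len > max_len:
        raise ValueError(f"sequence of length {seq_len} exceeds max_len={max_len}")
    return min(max_len, -(-seq_len // multiple) * multiple)  # ceiling division


def pad_batch(batch, pad_id, multiple=64, max_len=1024):
    """Right-pad every token list in `batch` to one shared bucketed length.

    Returns (padded_batch, bucketed_length).
    """
    target = bucket_length(max(len(seq) for seq in batch), multiple, max_len)
    return [seq + [pad_id] * (target - len(seq)) for seq in batch], target
```

Padded positions should be excluded from the loss, e.g. by setting the corresponding targets to an ignore_index. The batch dimension counts too: a short final batch is a new shape, which DataLoader(..., drop_last=True) avoids.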

As far as logging goes, setting TORCH_COMPILE_DEBUG=1 in the environment (TORCH_COMPILE_DEBUG=1 python ...) is the single most useful flag available.

Yes! Sequence length was the first thing I looked into. Thanks for the flag suggestion; I’ll try it.

@marksaroufim around the time of the spike, the only things this flag logs are:

[2023-05-04 17:19:26,426] torch._dynamo.eval_frame: [DEBUG] skipping <lambda> .../.venv/lib/python3.8/site-packages/torch/_dynamo/
[2023-05-04 17:19:26,426] torch._dynamo.eval_frame: [DEBUG] skipping _remove_id .../.venv/lib/python3.8/site-packages/torch/_dynamo/

Nothing else…