Hi PyTorch friends, is there a safe API I can call to manually reshard FSDP modules?
Context:
We’re trying a batch-size auto-tuning idea. We use FSDP and start with a large batch size. We catch the OOM exception, then retry with a smaller batch size, repeating until we find a batch size that doesn’t OOM.
It looks like this:
while True:
    try:
        train_one_batch(fsdp_model, input_data)
        break  # success: keep this batch size
    except torch.cuda.OutOfMemoryError:
        # reduce batch size
        # do I need to re-shard some of the FSDP modules manually here?
        ...
        continue
But we found that when we retry with a smaller batch size after an OOM, GPU memory usage is higher than it was right after the model was initialized. My guess is that when the OOM happens, some FSDP modules may still be in the unsharded state, and we need to shard them manually. If that hypothesis is right, is there already a safe API I can use to reshard the FSDP modules before training with the smaller batch size?
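For reference, here is roughly what I imagine the recovery path looking like. This is only a sketch, assuming a recent PyTorch (2.6+) where modules wrapped with the FSDP2 fully_shard API are FSDPModule instances (exported from torch.distributed.fsdp) with a public reshard() method; train_one_batch, make_batch, and initial_batch_size are placeholders for our own code.

# A minimal sketch, not a verified fix: assumes PyTorch 2.6+ FSDP2
# (fully_shard), where wrapped modules are FSDPModule instances with a
# public reshard() method.
import torch
from torch.distributed.fsdp import FSDPModule

def reshard_all(model: torch.nn.Module) -> None:
    # Walk every submodule and force FSDP-wrapped ones back to the
    # sharded state, freeing any all-gathered (unsharded) parameters
    # that the aborted step may have left allocated.
    for module in model.modules():
        if isinstance(module, FSDPModule):
            module.reshard()

batch_size = initial_batch_size
while True:
    try:
        train_one_batch(fsdp_model, make_batch(batch_size))
        break  # this batch size fits
    except torch.cuda.OutOfMemoryError:
        batch_size //= 2
        reshard_all(fsdp_model)   # drop unsharded params before retrying
        torch.cuda.empty_cache()  # return freed blocks to CUDA

If we’re still on the older FSDP1 wrapper (FullyShardedDataParallel), I haven’t found an equivalent public reshard call, which is really the core of my question.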