I am coming across some behavior I am unable to explain:
- From the docs, my understanding is that setting `batch_size=64` in my dataloaders with `split_batches=False` in Accelerate gives an effective batch size of 512. This runs without issues. However, `batch_size=512` with `split_batches=True` should give the same global batch size, yet that configuration gives me an OOM. Shouldn't both approaches behave the same (or at least not OOM)? A sketch of the two setups I am comparing follows this list.
- I would also like to know the right way to call non-forward methods of my model. Is the recommended way to always run them inside `FSDP.summon_full_params()`? In what cases should I set `ignored_modules` in the `FullyShardedDataParallelPlugin`, and when should I choose otherwise? (A rough sketch of the `ignored_modules` route is at the end of this post.)
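For reference, here is a minimal sketch of the two setups I am comparing, assuming 8 processes and the `split_batches` argument on `Accelerator` (the dummy dataset is only for illustration):

```python
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator

dataset = torch.arange(4096)  # dummy dataset standing in for my real one

# Setup A: per-process batch size of 64; Accelerate does not split batches.
# With 8 processes the global batch size is 64 * 8 = 512. This runs fine.
accelerator = Accelerator(split_batches=False)
train_loader = accelerator.prepare(DataLoader(dataset, batch_size=64))

# Setup B (the one that OOMs for me): the dataloader yields the full global
# batch of 512 and Accelerate splits it into 512 / 8 = 64 examples per process.
# accelerator = Accelerator(split_batches=True)
# train_loader = accelerator.prepare(DataLoader(dataset, batch_size=512))
```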
Here's a tiny code snippet showing how I intend to use `model.generate()`:
```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

for step, batch in enumerate(train_loader):
    outputs = model(batch)
    accelerator.backward(outputs.loss)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
    if step % 5 == 0:
        # gather the full parameters before calling the non-forward method
        with torch.inference_mode(), FSDP.summon_full_params(model):
            eval_text = model.generate(batch)
```
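
And for the second question, this is roughly what I mean by the `ignored_modules` route; the toy model below is just a placeholder for my real causal LM, and I would launch this with `accelerate launch`:

```python
import torch.nn as nn
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

# Placeholder model; my real model is a causal LM with a .generate() method.
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))

# Keep the second layer out of FSDP sharding so its parameters stay whole on
# every rank and can be used outside forward() without summoning full params.
fsdp_plugin = FullyShardedDataParallelPlugin(ignored_modules=[model[1]])
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
model = accelerator.prepare(model)
```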