I am using PyTorch’s FSDP. How can I save the complete (unsharded) model state every 100 batches? Is the following code (method 1) correct?
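```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Method 1: inside the training loop, after processing batch index `batch`
if (batch + 1) % 100 == 0:
    if isinstance(model, FSDP):
        states = model.state_dict()  # called on every rank
        if is_main_process:
            ckpt = {'model_state_dict': states}
            torch.save(ckpt, 'ckpt.pth')
    else:
        ckpt = {'model_state_dict': model.state_dict()}
        torch.save(ckpt, 'ckpt.pth')
```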
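For comparison, PyTorch's FSDP tutorials save a full checkpoint by making the state-dict type explicit with `FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True))`, so only rank 0 materializes the gathered weights, and on CPU rather than GPU. Below is a minimal sketch of method 1 rewritten that way. The function name `save_full_checkpoint`, its `batch_idx`/`every` parameters, and the fallback to rank 0 when no process group is initialized are illustrative assumptions, not part of the original code.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)


def save_full_checkpoint(model: nn.Module, path: str, batch_idx: int, every: int = 100) -> bool:
    """Hypothetical helper: save an unsharded checkpoint every `every` batches.

    Returns True only on the rank that actually wrote the file.
    """
    if (batch_idx + 1) % every != 0:
        return False
    # Assumption: treat a run without an initialized process group as rank 0.
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    if isinstance(model, FSDP):
        # Collective call: every rank must reach this point, or the gather hangs.
        policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, policy):
            state = model.state_dict()  # full dict on CPU on rank 0, empty dict on other ranks
    else:
        state = model.state_dict()
    if rank == 0:
        torch.save({"model_state_dict": state}, path)
        return True
    return False
```

It would be called from every rank at the same point in the loop, e.g. `save_full_checkpoint(model, 'ckpt.pth', batch)`; only the `torch.save` is restricted to rank 0, because the `state_dict()` call itself needs all ranks to participate.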
Or should I use method 2 instead?
```python
# Method 2
with FSDP.summon_full_params(
    module=model,
    rank0_only=True,
    writeback=False,
    offload_to_cpu=True,
):
    states = model.state_dict()
    ckpt = {'model_state_dict': states}
    torch.save(ckpt, 'ckpt.pth')
```
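Whichever method is used, one way to check that the saved file really holds the complete model is to load it on CPU into a fresh, unwrapped instance of the same architecture with `strict=True`: `load_state_dict` then raises a `RuntimeError` for missing or unexpected keys and for shape mismatches. This is a minimal sketch; the helper name `verify_full_checkpoint` and the explicit empty-dict check are illustrative assumptions.

```python
import torch
import torch.nn as nn


def verify_full_checkpoint(path: str, fresh_model: nn.Module) -> None:
    """Hypothetical helper: load a checkpoint into an unwrapped (non-FSDP) model.

    Raises ValueError for an empty state dict, and RuntimeError (from
    load_state_dict with strict=True) if keys or shapes do not match.
    """
    ckpt = torch.load(path, map_location="cpu")
    state = ckpt["model_state_dict"]
    if not state:
        # Assumption: an empty dict means the file came from a rank that held no full state.
        raise ValueError("checkpoint contains an empty state dict")
    fresh_model.load_state_dict(state, strict=True)
```

Running this once on a single process after training starts is cheap compared with discovering an incomplete checkpoint when resuming.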