Hi,
I am training on 4 A100s using accelerate, with bf16 mixed precision and torch.compile. I am not sure whether this is an accelerate issue or whether the underlying problem is some torch behavior I don't understand.
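For context, the setup is roughly the following (a simplified sketch; build_model and the learning rate are illustrative, not my exact code):

import torch
from accelerate import Accelerator

# launched on 4 A100s via `accelerate launch`; bf16 mixed precision
accelerator = Accelerator(mixed_precision="bf16")

model = build_model()                                        # illustrative stand-in for my model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)   # lr is illustrative

# compile the model, then let accelerate wrap model + optimizer for DDP
model = torch.compile(model)
model, optimizer = accelerator.prepare(model, optimizer)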
Training is overall stable, but it sometimes collapses (without diverging) right when a checkpoint is saved. The collapse manifests as very low gradient norms and the optimizer skipping steps. Sometimes training recovers after the next checkpoint save (every 5k steps in the charts). In short, something strange happens whenever I save a checkpoint (or, more precisely, whenever I call accelerator.wait_for_everyone()), but I have no idea what…
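For reference, the checkpoint call in the training loop looks roughly like this (simplified; the 5k interval is the one visible in the charts, the path is illustrative):

# inside the training loop (simplified)
if int(self.steps) % 5000 == 0:
    self.accelerator.wait_for_everyone()                  # sync all ranks before checkpointing
    self.save(f"./checkpoints/ckpt-{int(self.steps)}.pt")  # path is illustrative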
This is how I save:
def save(self, path, overwrite=True):
    if not self.can_checkpoint:
        return

    fs = self.fs
    assert not (fs.exists(path) and not overwrite)

    save_obj = dict(
        model=self.unwrapped_model.state_dict(),
        steps=self.steps.cpu(),
        true_steps=self.true_steps.cpu(),
        optimizer=self.optimizer.state_dict(),
    )

    if exists(self.scheduler):
        save_obj["scheduler"] = self.scheduler.state_dict()

    if self.use_ema:
        save_obj["ema"] = self.unwrapped_ema_model.state_dict()

    # save to path
    with fs.open(path, "wb") as f:
        self.accelerator.save(save_obj, f)

    self.logger.info(f"checkpoint saved to {path}")
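For completeness, here is roughly how the attributes used in save() are set up (a simplified sketch under assumptions, not my exact code: I'm assuming can_checkpoint gates on the main process and unwrapped_model comes from accelerator.unwrap_model):

# rough sketch of the attributes save() relies on (assumptions, not exact code)
self.can_checkpoint = self.accelerator.is_main_process              # only rank 0 writes the file
self.unwrapped_model = self.accelerator.unwrap_model(self.model)    # strip the wrapper added by prepare()
if self.use_ema:
    self.unwrapped_ema_model = self.accelerator.unwrap_model(self.ema_model)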
I would greatly appreciate any help… many thanks in advance!