I was training a GPT-2 XL-sized LLM and had to stop the run. When I try to resume the run on the same hardware, I get an OOM. I hit a similar issue when my model had about 930M parameters, and I solved it then by moving all tensors in the model/optimizer state dicts to CPU before saving. When I run this code:
```python
optimizer.state = collections.defaultdict(dict)
```
the OOM goes away. The OOM always happens during the optimizer step. I use xm.optimizer_step with barrier=True. I have also tried manually sharding the restored optimizer states using xs.mark_sharding (roughly as sketched after the list below). Here are some details about my project/setup:
- TPU v3-8
- Torch 2.7.0
- jax 0.6.2
- I use FSDP with SPMD (FSDPv2)
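
For reference, my manual sharding attempt looked roughly like this (a simplified sketch: I re-annotate the Adam moment buffers after loading the optimizer state; `mesh` is the SPMD mesh I pass to FSDPv2 and the "fsdp" axis name comes from my setup, so treat the partition specs as illustrative):

```python
import torch_xla.distributed.spmd as xs

# Simplified sketch: after optimizer.load_state_dict(...), re-annotate the
# Adam moment buffers with the same sharding layout as their parameters.
for group in optimizer.param_groups:
    for p in group["params"]:
        state = optimizer.state.get(p, {})
        for key in ("exp_avg", "exp_avg_sq"):
            buf = state.get(key)
            if buf is None or buf.dim() == 0:
                continue  # skip missing entries and scalars like "step"
            # Shard dim 0 across the "fsdp" mesh axis, replicate the rest.
            partition_spec = ("fsdp",) + (None,) * (buf.dim() - 1)
            xs.mark_sharding(buf, mesh, partition_spec)
```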
Here is some relevant code from my codebase:
Saving:
```python
def save_checkpoint(model, optimizer, step, train_device_loader=None):
    # Save model weights via XLA SPMD checkpoint (supported)
    os.makedirs(f"./ckpt-{step}", exist_ok=True)
    model_state_dict = model.module.state_dict()
    for i in model_state_dict.keys():
        xla_tensor = model_state_dict[i]
        model_state_dict[i] = xla_tensor.to("cpu")
        del xla_tensor
    model_sd = {"model": model_state_dict}
    xm.save(model_sd, f"./ckpt-{step}/model.pt")

    # Save host-only states separately (optimizer, step, RNG, dataloader)
    optim_state = optimizer.state_dict()
    optim_state_for_saving = {
        "state": {},
        "param_groups": optimizer.state_dict()["param_groups"],
    }
    for i in optim_state["state"]:
        optim_state_for_saving["state"][i] = {}
        optim_state_for_saving["state"][i]["step"] = optim_state["state"][i]["step"].to("cpu")
        optim_state_for_saving["state"][i]["exp_avg"] = optim_state["state"][i]["exp_avg"].to("cpu")
        optim_state_for_saving["state"][i]["exp_avg_sq"] = optim_state["state"][i]["exp_avg_sq"].to("cpu")
    host_state = {
        "optim": optim_state_for_saving,
        "step": step,
    }
    if train_device_loader:
        rng_states = {
            'torch_rng_state': torch.get_rng_state(),
            'numpy_rng_state': np.random.get_state(),
            'random_rng_state': random.getstate(),
        }
        dataloader_states = {
            "shard_order": train_device_loader._loader.dataset.shards,
            "local_order": train_device_loader._loader.dataset.curr_order,
            "warmup_order": train_device_loader._loader.dataset.warmup_order,
            "warmup_prob": train_device_loader._loader.dataset.warmup_prob,
        }
    else:
        rng_states = None
        dataloader_states = None

    # Write host-side files
    with open(f"./ckpt-{step}/host_state.pkl", "wb") as f:
        pickle.dump(host_state, f)
    if rng_states is not None:
        with open(f"./ckpt-{step}/rng.pkl", "wb") as f:
            pickle.dump(rng_states, f)
    if dataloader_states is not None:
        with open(f"./ckpt-{step}/dataloader.json", "w") as json_file:
            json.dump(dataloader_states, json_file, indent=4)
```

Loading:
if resume_from != "": model_sd = torch.load(f"{resume_from}/model.pt", map_location='cpu') model.load_state_dict(model_sd["model"]) model = model.to(device) if gradient_checkpointing: model = FSDPv2(module=checkpoint_module(model), mesh=mesh) else: model = FSDPv2(module=model, mesh=mesh) optimizer = build_optimizer(model, peak_lr, betas, weight_decay) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps*(1-warmup_pct), eta_min=min_lr) if resume_from != "": xm.mark_step() # 2) Restore host-only states (optimizer, step) with open(f"{resume_from}/host_state.pkl", 'rb') as f: host_state = pickle.load(f) optim_state = host_state["optim"] # Load the processed state dict optimizer.load_state_dict(optim_state) del optim_state last_step = host_state["step"] # 3) Restore RNG and dataloader state (if present) try: with open(f"{resume_from}/rng.pkl", "rb") as f: rng = pickle.load(f) torch.set_rng_state(rng['torch_rng_state']) np.random.set_state(rng['numpy_rng_state']) random.setstate([rng['random_rng_state'][0], tuple(rng['random_rng_state'][1]), rng['random_rng_state'][2]]) except FileNotFoundError: pass with open(f'{resume_from}/dataloader.json', 'r') as file: dataloader = json.load(file)Step:
```python
for k in range(gradient_accumulation_steps):
    x, y = next(train_iter)
    with autocast(xm.xla_device(), dtype=torch.bfloat16):
        loss = model(x, y)
    (loss / gradient_accumulation_steps).backward()
    train_loss += loss.detach()
    xm.mark_step()

torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
xm.optimizer_step(optimizer, barrier=True)
optimizer.zero_grad()
```
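
If it helps with debugging, I can capture device memory right before the failing call with something like this (a minimal sketch; I'm assuming xm.get_memory_info is available in this torch_xla version):

```python
import torch_xla.core.xla_model as xm

def log_memory(tag):
    # Dump the raw memory-info dict for the current XLA device
    # (field names vary between torch_xla versions, so I just print everything).
    info = xm.get_memory_info(xm.xla_device())
    print(f"[{tag}] {info}")

# Placed around the call that OOMs:
log_memory("before optimizer_step")
xm.optimizer_step(optimizer, barrier=True)
log_memory("after optimizer_step")
```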