OOM When Resuming From Checkpoint (PyTorch/XLA)

I was training a GPT-2 XL-sized LLM and had to stop the run. When I try to resume the run on the same hardware, I get an OOM. I hit a similar issue earlier when the model had about 930M parameters, and I solved it then by moving all tensors in the model/optimizer state dicts to CPU before saving. When I run this code:

    optimizer.state = collections.defaultdict(dict)

the OOM goes away. The OOM always happens during the optimizer step. I use xm.optimizer_step with the barrier enabled, and I have also tried manually sharding the optimizer states with xs.mark_sharding (a sketch of that attempt follows the setup list below). Here are some details about my project/setup:

  • TPU v3-8

  • Torch 2.7.0

  • jax 0.6.2

  • I use FSDP with SPMD
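
For reference, the xs.mark_sharding attempt on the optimizer states looked roughly like the sketch below. It is trimmed, not verbatim, and it assumes the FSDPv2 mesh has an 'fsdp' axis that dim 0 of each parameter is sharded over:

    import torch_xla.distributed.spmd as xs

    def shard_optimizer_state(optimizer, mesh):
        # Shard the Adam moments the same way FSDPv2 shards the corresponding
        # parameters: dim 0 over the 'fsdp' mesh axis, all other dims replicated.
        for state in optimizer.state.values():
            for key in ("exp_avg", "exp_avg_sq"):
                t = state.get(key)
                if t is not None and t.dim() > 0:
                    spec = ("fsdp",) + (None,) * (t.dim() - 1)
                    xs.mark_sharding(t, mesh, spec)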

Here is some relevant code from my codebase:
Saving:

    def save_checkpoint(model, optimizer, step, train_device_loader=None):
        # Save model weights via XLA SPMD checkpoint (supported)
        os.makedirs(f"./ckpt-{step}", exist_ok=True)
        model_state_dict = model.module.state_dict()
        for i in model_state_dict.keys():
            xla_tensor = model_state_dict[i]
            model_state_dict[i] = xla_tensor.to("cpu")
            del xla_tensor
        model_sd = {"model": model_state_dict}
        xm.save(model_sd, f"./ckpt-{step}/model.pt")

        # Save host-only states separately (optimizer, step, RNG, dataloader)
        optim_state = optimizer.state_dict()
        optim_state_for_saving = {
            "state": {},
            "param_groups": optim_state["param_groups"],
        }
        for i in optim_state["state"]:
            optim_state_for_saving["state"][i] = {}
            optim_state_for_saving["state"][i]["step"] = optim_state["state"][i]["step"].to("cpu")
            optim_state_for_saving["state"][i]["exp_avg"] = optim_state["state"][i]["exp_avg"].to("cpu")
            optim_state_for_saving["state"][i]["exp_avg_sq"] = optim_state["state"][i]["exp_avg_sq"].to("cpu")
        host_state = {
            "optim": optim_state_for_saving,
            "step": step,
        }

        if train_device_loader:
            rng_states = {
                'torch_rng_state': torch.get_rng_state(),
                'numpy_rng_state': np.random.get_state(),
                'random_rng_state': random.getstate(),
            }
            dataloader_states = {
                "shard_order": train_device_loader._loader.dataset.shards,
                "local_order": train_device_loader._loader.dataset.curr_order,
                "warmup_order": train_device_loader._loader.dataset.warmup_order,
                "warmup_prob": train_device_loader._loader.dataset.warmup_prob,
            }
        else:
            rng_states = None
            dataloader_states = None

        # Write host-side files
        with open(f"./ckpt-{step}/host_state.pkl", "wb") as f:
            pickle.dump(host_state, f)
        if rng_states is not None:
            with open(f"./ckpt-{step}/rng.pkl", "wb") as f:
                pickle.dump(rng_states, f)
        if dataloader_states is not None:
            with open(f"./ckpt-{step}/dataloader.json", "w") as json_file:
                json.dump(dataloader_states, json_file, indent=4)

Loading:

    # 1) Restore model weights (loaded on CPU, then moved to device and wrapped)
    if resume_from != "":
        model_sd = torch.load(f"{resume_from}/model.pt", map_location='cpu')
        model.load_state_dict(model_sd["model"])
    model = model.to(device)
    if gradient_checkpointing:
        model = FSDPv2(module=checkpoint_module(model), mesh=mesh)
    else:
        model = FSDPv2(module=model, mesh=mesh)
    optimizer = build_optimizer(model, peak_lr, betas, weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps*(1-warmup_pct), eta_min=min_lr)
    if resume_from != "":
        xm.mark_step()
        # 2) Restore host-only states (optimizer, step)
        with open(f"{resume_from}/host_state.pkl", 'rb') as f:
            host_state = pickle.load(f)
        optim_state = host_state["optim"]

        # Load the processed state dict
        optimizer.load_state_dict(optim_state)
        del optim_state
        last_step = host_state["step"]
        # 3) Restore RNG and dataloader state (if present)
        try:
            with open(f"{resume_from}/rng.pkl", "rb") as f:
                rng = pickle.load(f)
            torch.set_rng_state(rng['torch_rng_state'])
            np.random.set_state(rng['numpy_rng_state'])
            random.setstate([rng['random_rng_state'][0], tuple(rng['random_rng_state'][1]), rng['random_rng_state'][2]])
        except FileNotFoundError:
            pass
        with open(f'{resume_from}/dataloader.json', 'r') as file:
            dataloader = json.load(file)
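
To be explicit, the workaround mentioned at the top is just the following reset, which makes the OOM disappear at the cost of throwing away the Adam moments:

    import collections

    # Replacing the per-parameter state with an empty defaultdict avoids the
    # OOM on resume, but the exp_avg/exp_avg_sq loaded from the checkpoint are lost.
    optimizer.state = collections.defaultdict(dict)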
    

Step:

    for k in range(gradient_accumulation_steps):
        x, y = next(train_iter)
        with autocast(xm.xla_device(), dtype=torch.bfloat16):
            loss = model(x, y)
        (loss / gradient_accumulation_steps).backward()
        train_loss += loss.detach()
        xm.mark_step()  # cut the lazy graph after each micro-batch

    torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)

    xm.optimizer_step(optimizer, barrier=True)  # this is where the OOM happens

    optimizer.zero_grad()
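
In case it helps, here is roughly how I can watch device memory around the failing step. This is only a sketch: xm.get_memory_info is the call I'd use, and since the keys of the returned dict differ between torch_xla releases I just print the whole thing:

    device = xm.xla_device()

    xm.mark_step()  # flush pending ops so the numbers reflect materialized buffers
    print("before optimizer step:", xm.get_memory_info(device))

    xm.optimizer_step(optimizer, barrier=True)

    print("after optimizer step:", xm.get_memory_info(device))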