FSDP for multi-GPU encounters ValueError: Inconsistent compute device and `device_id` on rank 1: cuda:0 vs cuda:1

Hi,

I’m experimenting with FSDP by creating a simple MLP with two FC layers:

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_layers = nn.Sequential(
            nn.Linear(1000, 35000),
            nn.Linear(35000, 1000)
        )
        
    def forward(self, x):
        return self.fc_layers(x)

My testing environment (just a Kaggle notebook) is:

  • Python 3.10.12
  • PyTorch: 2.5.1+cu121
  • Nodes: 1
  • Procs/node: 2
  • RAM: 29GB
  • VRAM: 15GB/GPU

Hence, the VRAM requirement for training the MLP above is (a quick check in code follows the list below):

(1001 * 35000 + 35001 * 1000) * 20 * 10 / 10^9 ~ 14GB

where:

  • Parameters are float32 (x4)
  • Gradients are float32 (x4)
  • Optimizer states: Parameter copy, momentum, and variance all in float32 (x12)
    => x20.
  • Batch size of 10.
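
As a quick check of that arithmetic (the 20-bytes-per-parameter factor and the batch-size multiplier are exactly the assumptions listed above; estimate_vram_gb is just an ad-hoc helper for this post, not part of the training code):

def estimate_vram_gb(hidden_dim, bytes_per_param = 20, batch_size = 10):
    # Two Linear layers with biases: (1000 -> hidden_dim) and (hidden_dim -> 1000).
    n_params = (1000 + 1) * hidden_dim + (hidden_dim + 1) * 1000
    return n_params * bytes_per_param * batch_size / 1e9

print(estimate_vram_gb(35000))  # ~14.0 (GB)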

This essentially fits on 1 card, and indeed, when I run the training script on Kaggle, training completes with the loss decreasing throughout the iterations.

Now, I change the MLP to:

self.fc_layers = nn.Sequential(
    nn.Linear(1000, 70000),
    nn.Linear(70000, 1000)
)
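
Plugging the new hidden size into the same ad-hoc helper from above:

print(estimate_vram_gb(70000))  # ~28.0 (GB)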

The VRAM requirement, computed the same way as above, increases to approximately 28GB, which should be okay in the 2-GPU environment. But when I run the training script, this error comes up:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/kaggle/working/learn_dist/learn_dist/train.py", line 55, in <module>
[rank1]:     fsdp_model = FullyShardedDataParallel(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
[rank1]:     _init_param_handle_from_module(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_init_utils.py", line 618, in _init_param_handle_from_module
[rank1]:     state.compute_device = _get_compute_device(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_init_utils.py", line 1082, in _get_compute_device
[rank1]:     raise ValueError(
[rank1]: ValueError: Inconsistent compute device and `device_id` on rank 1: cuda:0 vs cuda:1
Epoch 1/2
Processing batch 0 ...

So the error occurs while initializing the FullyShardedDataParallel class, even before the first iteration. I'm confused because it didn't occur with the previous MLP configuration, but it does with this one. This is the training code I experimented with:

import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.utils.data import DataLoader

if __name__ == "__main__":

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{rank}")
    # torch.cuda.set_device(torch.cuda.current_device())
    dist.init_process_group(backend = "nccl")

    model = MLP().to(device)
    fsdp_model = FullyShardedDataParallel(
        model,
        device_id = device,
        auto_wrap_policy = size_based_auto_wrap_policy,
        cpu_offload = CPUOffload(offload_params = True)
    )
    optim = torch.optim.Adam(fsdp_model.parameters(), lr = 0.0001)
    custom_dataset = gen_data(100)  # gen_data is defined in my repo and builds the toy dataset
    dataloader = DataLoader(custom_dataset, batch_size = 10, shuffle = True)
    
    for i in range(2):
        print(f"Epoch {i+1}/2")
        fsdp_model.train()
        for batch_idx, (ts_batch, label_batch) in enumerate(dataloader):
            print(f"Processing batch {batch_idx} ...")
            ts_batch = ts_batch.to(device)
            label_batch = label_batch.to(device)
            optim.zero_grad()
            logits = fsdp_model(ts_batch)
            loss = F.cross_entropy(logits, label_batch)
            print(f"loss = {loss.item()}")
            loss.backward()
            optim.step()

where I wrap the model in FullyShardedDataParallel with a size-based auto-wrap policy and CPU offload. Does this error have anything to do with OOM, or did I not set up FSDP correctly? I'd appreciate any help with this scenario.
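
For context on what the error message is comparing, here is the kind of per-rank check that could be dropped in right before the FullyShardedDataParallel(...) call (just an illustrative sketch, not something that is in my repo or reflected in the output above):

# Illustrative diagnostic: print each rank's current CUDA device and the device
# the (not yet wrapped) model's parameters actually sit on.
print(
    f"rank {rank}: current_device = cuda:{torch.cuda.current_device()}, "
    f"param device = {next(model.parameters()).device}, device_id arg = {device}"
)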

Thanks.

P.S.: You can fully reproduce the results above by opening a Kaggle notebook (or any other environment with the same settings), cloning my toy repo, and running the bash script.