How do I get data onto the same device as an FSDP model during training?

Hi,

I’m currently learning how to use FSDP to shard a model across GPUs, and I have a toy example, shown below, where I shard a simple multilayer perceptron (MLP) using FullyShardedDataParallel:

import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.utils.data import DataLoader

# MLP, gen_data, and NUM_EPOCHS are defined elsewhere in train.py.

if __name__ == "__main__":

    # Read the rank info that torchrun puts in the environment and set up NCCL.
    torch.cuda.set_device(torch.cuda.current_device())
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group(backend = "nccl")

    # Wrap the plain MLP with FSDP, size-based auto-wrapping, and CPU offload.
    model = MLP()
    fsdp_model = FullyShardedDataParallel(
        model,
        device_id = rank,  # see question 1: torch.cuda.current_device() raised an error here
        auto_wrap_policy = size_based_auto_wrap_policy,
        cpu_offload = CPUOffload(offload_params = True)
    )
    optim = torch.optim.Adam(fsdp_model.parameters(), lr = 0.0001)
    custom_dataset = gen_data(100)
    dataloader = DataLoader(custom_dataset, batch_size = 10, shuffle = True)

    # See question 2: this turns out to be cpu, not a cuda device.
    device = next(fsdp_model.parameters()).device
    for i in range(NUM_EPOCHS):
        fsdp_model.train()
        for batch_idx, (ts_batch, label_batch) in enumerate(dataloader):
            ts_batch = ts_batch.to(device)
            label_batch = label_batch.to(device)
            optim.zero_grad()
            logits = fsdp_model(ts_batch)
            loss = F.cross_entropy(logits, label_batch)
            loss.backward()
            optim.step()

My cluster environment is a single node with 2 A100s, so I launch with:

torchrun --nproc_per_node=2 --nnodes=1 train.py
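
For debugging, I also added a quick per-process print near the top of train.py (a throwaway snippet of mine, relying on the RANK / LOCAL_RANK / WORLD_SIZE variables that torchrun exports):

# Throwaway sanity check: torchrun sets these env vars in every worker process.
print(
    f"RANK={os.environ['RANK']} "
    f"LOCAL_RANK={os.environ['LOCAL_RANK']} "
    f"WORLD_SIZE={os.environ['WORLD_SIZE']} "
    f"torch.cuda.current_device()={torch.cuda.current_device()}"
)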

My first question: when I pass the device_id argument of FullyShardedDataParallel() as torch.cuda.current_device(), I get the error Duplicate GPU detected : rank 1 and rank 0 both on CUDA device 7000, so I set it to the rank instead, which resolves the error. The FSDP docs state that this argument specifies the CUDA device on which to initialize the model, shard parameters, etc., and they suggest that we may pass torch.cuda.current_device(), yet doing so produces the error above. When I print both the rank and torch.cuda.current_device(), I get:

In rank 0: cuda.current_device: 0
In rank 1: cuda.current_device: 0

Thus, I’m confused why torch.cuda.current_device() returns 0 for both ranks even though they are on the same node. And for the FSDP constructor, is passing torch.cuda.current_device() to the device_id argument actually correct?
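
Concretely, the two variants I tried differ only in the device_id line; everything else is exactly as in the full example above:

# Variant A: what I read the docs as suggesting -- this is the one that fails
# with "Duplicate GPU detected : rank 1 and rank 0 both on CUDA device 7000".
fsdp_model = FullyShardedDataParallel(
    model,
    device_id = torch.cuda.current_device(),
    auto_wrap_policy = size_based_auto_wrap_policy,
    cpu_offload = CPUOffload(offload_params = True)
)

# Variant B: passing the global rank instead -- this runs without the error
# (on a single node the global rank happens to equal the GPU index).
fsdp_model = FullyShardedDataParallel(
    model,
    device_id = rank,
    auto_wrap_policy = size_based_auto_wrap_policy,
    cpu_offload = CPUOffload(offload_params = True)
)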

For my second question: before the training loop, I attempted to get the model’s device via:

device = next(fsdp_model.parameters()).device

and in the training loop I move the data tensors to this device. But it turns out that this device is just cpu, and only after the forward pass is the device of logits either cuda:0 or cuda:1. To correctly transfer the data tensors to the same device as fsdp_model, don’t I need the CUDA device before the forward pass?
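
In case it’s relevant, the workaround I’m experimenting with is to build the device from LOCAL_RANK (which torchrun sets per process) instead of reading it off the parameters; I’m not sure whether this is the intended approach:

# Workaround sketch: derive the device from the per-process local rank
# (set by torchrun) rather than from the CPU-offloaded parameters.
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")

for batch_idx, (ts_batch, label_batch) in enumerate(dataloader):
    ts_batch = ts_batch.to(device)
    label_batch = label_batch.to(device)
    ...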

Thanks!