Hi,
I’m currently learning how to use FSDP to shard a model across GPUs, and I have a toy example as follows where I shard a simple multilayer perceptron using FullyShardedDataParallel:
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.utils.data import DataLoader

# MLP, gen_data, and NUM_EPOCHS are defined elsewhere in my script.

if __name__ == "__main__":
    torch.cuda.set_device(torch.cuda.current_device())
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group(backend="nccl")

    model = MLP()
    fsdp_model = FullyShardedDataParallel(
        model,
        device_id=rank,
        auto_wrap_policy=size_based_auto_wrap_policy,
        cpu_offload=CPUOffload(offload_params=True),
    )
    optim = torch.optim.Adam(fsdp_model.parameters(), lr=0.0001)

    custom_dataset = gen_data(100)
    dataloader = DataLoader(custom_dataset, batch_size=10, shuffle=True)

    device = next(fsdp_model.parameters()).device

    for i in range(NUM_EPOCHS):
        fsdp_model.train()
        for batch_idx, (ts_batch, label_batch) in enumerate(dataloader):
            ts_batch = ts_batch.to(device)
            label_batch = label_batch.to(device)
            optim.zero_grad()
            logits = fsdp_model(ts_batch)
            loss = F.cross_entropy(logits, label_batch)
            loss.backward()
            optim.step()
My cluster environment is a single node with 2 A100s, so I launch with:
torchrun --nproc_per_node=2 --nnodes=1 train.py
My first question: when I pass the device_id arg of FullyShardedDataParallel() as torch.cuda.current_device(), I get the error Duplicate GPU detected : rank 1 and rank 0 both on CUDA device 7000, so I set it to the rank instead, which resolves the error. The FSDP docs state that this arg specifies the CUDA device on which to initialize the model, shard parameters, etc., and they suggest that we may pass torch.cuda.current_device, but that results in the error above. When I print out both the rank and torch.cuda.current_device():
In rank 0: cuda.current_device: 0
In rank 1: cuda.current_device: 0
So I’m confused why torch.cuda.current_device() returns 0 for both ranks even though they are on the same node. And for the FSDP wrapping, is passing torch.cuda.current_device to the device_id arg correct?
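Should I instead be binding each process to its GPU from the local rank before constructing FSDP? Here is a minimal sketch of what I mean, assuming LOCAL_RANK (which I believe torchrun sets per process) is the right env var to use here; I’m not sure this is the recommended pattern:

import os
import torch

# Sketch of my debugging attempt (assumption: torchrun sets LOCAL_RANK per process).
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)   # bind this process to its own GPU
print(torch.cuda.current_device())  # I expected 0 on rank 0 and 1 on rank 1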
For the second question: before the training loop, I attempt to get the device via:
device = next(fsdp_model.parameters()).device
and in the body of the training loop I move the data tensors to this device. But it turns out that this device is just cpu, and only after the forward pass is the device of logits either cuda:0 or cuda:1. To correctly transfer the data tensors to the same device as fsdp_model, don’t we need the CUDA device first?
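As a workaround I currently build the device from the local rank myself instead of reading it off the parameters. A minimal sketch of that, under the same assumption as above (torchrun provides LOCAL_RANK and there is one process per GPU):

import os
import torch

# Workaround sketch (assumption: one process per GPU, LOCAL_RANK set by torchrun).
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device("cuda", local_rank)

# used in the training loop instead of next(fsdp_model.parameters()).device:
# ts_batch = ts_batch.to(device)
# label_batch = label_batch.to(device)

Is this the intended way to pick the input device when cpu_offload is enabled, or is there an FSDP-native way to query it?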
Thanks!