I am fine-tuning a Llama model using FSDP on 2 GPUs. While debugging an error, I am printing model.device during runs. Unintuitively, different device ids are printed on different iterations (sometimes 1 and sometimes 0). This made me wonder what model.device actually is when using FSDP: since the model is sharded across both devices, what determines this value? Furthermore, if I want to know which device my input layer is sharded onto (since that should be the actual model device), how do I find that out? And how do I ensure that I am casting my inputs to the correct device id (or how does torch do this, if it happens automatically)? I am running into obscure illegal memory access errors and I need to get to the root of this.
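Concretely, the kind of pattern I am unsure about looks like this (hypothetical, simplified sketch; `model` is the FSDP-wrapped Llama model and `batch` is a dict of tensors from the tokenizer):

```python
# Hypothetical sketch of the input casting I am asking about.
batch = {k: v.to(model.device) for k, v in batch.items()}  # is model.device the right target?
outputs = model(**batch)
```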
Thanks
Did you call something like torch.cuda.set_device(local_rank)?
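If not, the usual pattern is to set the device from the launcher's local rank before initializing the process group. A minimal sketch, assuming the script is launched with torchrun (which sets LOCAL_RANK per process):

```python
import os

import torch
import torch.distributed as dist

# torchrun sets LOCAL_RANK for each process; with 2 GPUs this is 0 or 1.
local_rank = int(os.environ["LOCAL_RANK"])

# Bind this process to its own GPU so that "current device" is unambiguous.
torch.cuda.set_device(local_rank)

# NCCL expects each rank to own exactly one device.
dist.init_process_group(backend="nccl")
```

Skipping the set_device call is a common cause of errors like illegal memory access, since both ranks would otherwise default to cuda:0.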
As far as I know, FullyShardedDataParallel instances do not have a .device attribute. Do you mean .compute_device? If so, the compute_device should be set at FSDP initialization time and never changed thereafter.
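A quick way to check this and to cast inputs consistently might look like the following (a rough sketch; `model` and `batch` stand in for your Llama model and tokenized inputs, and I'm assuming your FSDP version exposes compute_device on the wrapper as described above):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

fsdp_model = FSDP(model)  # `model` is the Llama model, built elsewhere

# On each rank this should print that rank's own GPU (e.g. cuda:0 or cuda:1)
# and stay the same for the whole run.
print(fsdp_model.compute_device)

# Casting inputs to the rank's current device explicitly keeps the forward
# pass and the sharded parameters on the same GPU.
device = torch.device("cuda", torch.cuda.current_device())
batch = {k: v.to(device) for k, v in batch.items()}
outputs = fsdp_model(**batch)
```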