What is `model.device` in FSDP?

I am finetuning a llama model using FSDP on 2 GPUs. While debugging an error, I am printing `model.device` during runs. Unintuitively, different device ids are printed on different iterations (sometimes 1 and sometimes 0). This got me wondering what `model.device` actually is under FSDP. Since the model is sharded across both devices, what determines this value? Furthermore, if I want to know which device my input layer is sharded onto (since that should be the actual model device), how do I find that out? How do I ensure that I am casting my inputs to the correct device id (or how does torch do this, if it's done automatically)? I am running into obscure illegal memory access errors and need to get to the root of this.

Did you call something like `torch.cuda.set_device(local_rank)`?
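For context, this is the usual per-rank device setup in a script launched with `torchrun --nproc_per_node=2 train.py` — a minimal sketch, assuming `torchrun` is used (it sets the `LOCAL_RANK` env var); the CPU fallback is just so the snippet runs anywhere:

```python
import os
import torch

# torchrun sets LOCAL_RANK per process; default to 0 for single-process runs.
local_rank = int(os.environ.get("LOCAL_RANK", 0))

if torch.cuda.is_available():
    # Pin this process to its own GPU *before* constructing the FSDP model,
    # so CUDA ops and FSDP's compute device default to the right GPU.
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
else:
    device = torch.device("cpu")  # fallback so the sketch runs without GPUs

# Inputs then get moved to this same per-rank device each step, e.g.:
# batch = {k: v.to(device) for k, v in batch.items()}
print(device)
```

If you skip the `set_device` call, every rank defaults to `cuda:0`, which is a classic source of illegal memory access errors with sharded models.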

As far as I know, FullyShardedDataParallel instances do not have a .device attribute. Do you mean .compute_device? If so, the compute_device should be set at FSDP initialization time and never changed thereafter.
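To illustrate the distinction: a plain `nn.Module` has no `.device` attribute at all — what people usually inspect is the device of its parameters. (If the llama model comes from Hugging Face transformers, note that `PreTrainedModel` does define a `.device` property derived from its parameters, which may be what's being printed.) A minimal sketch, with `param_device` being a hypothetical helper, not a torch API:

```python
import torch
import torch.nn as nn

def param_device(module: nn.Module) -> torch.device:
    # Hypothetical helper: report the device of the module's first parameter.
    return next(module.parameters()).device

model = nn.Linear(4, 4)

# nn.Module defines no `.device`; attribute access would raise AttributeError.
print(hasattr(model, "device"))   # False
print(param_device(model))        # cpu, before any .to(...) call
```

With FSDP, each rank holds its own shard on its own GPU, so inspecting parameter devices from different ranks will report different ids — whereas `.compute_device` is fixed per rank at wrap time.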