How to use the `meta` device to efficiently load large models onto the GPU?

Hey there,

I’d like to use the `meta` device to load large models (e.g. MPT-30B) onto my GPU efficiently, without first creating the model on the CPU and then copying it over. I would also like to quantize on the fly :) so this seems like a necessary first step.

My understanding is that the meta device should be used for this purpose, but I’m not sure how to proceed.

I did trace through the accelerate implementation:

but it’s unclear to me why the `init_on_device` function is doing all this work.

Is there a simple, clear implementation that I could trace through to understand all the steps properly?


Hey - just wanted to bump this :)

I’m trying to understand how to properly use the meta device to load models directly onto the GPU. Could you please point me to a clear example? Thanks!

Just wanted to bump this up, as I’m still stuck on it. Thank you.
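
For anyone landing here with the same question, below is a minimal sketch of the general pattern as I understand it. It assumes PyTorch 2.0 or later (where `torch.device` can be used as a context manager) and is not the accelerate implementation. `build_model` and the tiny `nn.Sequential` are hypothetical stand-ins for a real large model, and the in-memory `state_dict` stands in for a checkpoint you would normally read from disk.

```python
import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical stand-in for a large model such as MPT-30B.
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))


# Pretend checkpoint. In practice this would be loaded from disk.
reference = build_model()
state_dict = reference.state_dict()

# Fall back to the CPU so the sketch also runs on a machine without a GPU.
target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Build the module skeleton on the meta device.
#    Parameters carry shape and dtype only; no memory is allocated for them.
with torch.device("meta"):
    model = build_model()
assert all(p.is_meta for p in model.parameters())

# 2. Allocate uninitialized storage for every parameter directly on the target.
model = model.to_empty(device=target)

# 3. Copy the checkpoint values into the freshly allocated parameters.
model.load_state_dict(state_dict)
```

The point of step 1 is that the full-size weights are never materialized with random initial values on the CPU. Step 3 is where a per-tensor transformation such as quantization could plausibly be hooked in, though that is not shown here.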