Discrepancy between loading models with meta tensors and normal load_state_dict

PyTorch's “meta” device (see torch.Tensor.is_meta in the PyTorch 1.13 documentation) was introduced a while back to load models into memory efficiently. More precisely, it lets one skip allocating memory for a model's weights before the actual weights are loaded, i.e.:

import torch
layer = torch.nn.Linear(100, 100)

allocates memory while

layer = torch.nn.Linear(100, 100, device="meta")

doesn’t. This is extremely important for the ML community because language models (e.g. GPT-J, BLOOM) have become so large that even CPU RAM can be a limiting factor.
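To make the difference concrete, here is a small check (a sketch using only public tensor attributes) showing that a meta parameter still carries its shape and dtype metadata but has no data behind it:

```python
import torch
from torch import nn

# Regular construction: weights are allocated and randomly initialized on CPU.
layer_cpu = nn.Linear(100, 100)
# Meta construction: only metadata (shape, dtype) is recorded, no storage is allocated.
layer_meta = nn.Linear(100, 100, device="meta")

print(layer_cpu.weight.is_meta)   # False
print(layer_meta.weight.is_meta)  # True

# The metadata matches the regular layer: same shape, same default dtype.
print(layer_meta.weight.shape, layer_meta.weight.dtype)
```

Note that the meta weight reports a dtype (float32 by default) even though no values exist, which is exactly what the dtype question further down is about.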

Another big advantage of device="meta" is that no actual weight initialization takes place (there is no data to fill), which can save up to 95% of the time spent loading models.
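A rough way to check the initialization savings on your own machine is to time module construction on both devices. This is only a sketch: time_construction is a hypothetical helper, the layer size is arbitrary, and the actual numbers depend entirely on hardware (the 95% figure above is not reproduced here).

```python
import time
import torch
from torch import nn

def time_construction(device, n=2048, repeats=3):
    """Hypothetical helper: best wall-clock time (s) to build an n x n nn.Linear on `device`."""
    best = float("inf")
    layer = None
    for _ in range(repeats):
        start = time.perf_counter()
        layer = nn.Linear(n, n, device=device)  # runs reset_parameters() on that device
        best = min(best, time.perf_counter() - start)
    return best, layer

cpu_time, cpu_layer = time_construction("cpu")
meta_time, meta_layer = time_construction("meta")
print(f"cpu:  {cpu_time:.4f}s")
print(f"meta: {meta_time:.4f}s")
```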

Hence, loading models with the meta device has already been widely adopted by the community:

The problem is that there is currently a discrepancy in dtype casting between loading a model with device="meta" and loading it without. Consider the following code snippet:

import os
import torch
from torch import nn
import tempfile

print("1. save & load raw state dicts in fp32 and fp16 precision")
with tempfile.TemporaryDirectory() as tmpdirname:
    layer = nn.Linear(100, 100)
    torch.save(layer.state_dict(), os.path.join(tmpdirname, "fp32.bin"))
    fp32_dict = torch.load(os.path.join(tmpdirname, "fp32.bin"))

with tempfile.TemporaryDirectory() as tmpdirname:
    layer = nn.Linear(100, 100).half()
    torch.save(layer.state_dict(), os.path.join(tmpdirname, "fp16.bin"))
    fp16_dict = torch.load(os.path.join(tmpdirname, "fp16.bin"))

print("2. load weights into fp32 dtype")
print(25 * "-")
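layer_fp32 = nn.Linear(100, 100)
layer_fp32.load_state_dict(fp32_dict)
print("Fp32 into Fp32 =>", next(iter(layer_fp32.parameters())).dtype)  # => fp32

layer_fp32 = nn.Linear(100, 100)
layer_fp32.load_state_dict(fp16_dict)  # fp16 values are copied (and upcast) into the fp32 weights
print("Fp16 into Fp32 =>", next(iter(layer_fp32.parameters())).dtype)  # => fp32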

print("3. load weights into fp16 dtype")
print(25 * "-")
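layer_fp16 = nn.Linear(100, 100, dtype=torch.float16)
layer_fp16.load_state_dict(fp32_dict)  # fp32 values are copied (and downcast) into the fp16 weights
print("Fp32 into Fp16 =>", next(iter(layer_fp16.parameters())).dtype)  # => fp16

layer_fp16 = nn.Linear(100, 100, dtype=torch.float16)
layer_fp16.load_state_dict(fp16_dict)
print("Fp16 into Fp16 =>", next(iter(layer_fp16.parameters())).dtype)  # => fp16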

# => Makes sense! Loaded weights are always upcast / downcast to the dtype of the randomly initialized weights
# => Now let's check out meta

print("4. load weights into meta")
print(25 * "-")
layer_meta = nn.Linear(100, 100, device="meta")
# Let's load weights as before
layer_meta.load_state_dict(fp32_dict)  # this is essentially a no-op since weights aren't filled; why? is this supposed to stay?
print("Show loaded weights when using load_state_dict: ", next(iter(layer_meta.parameters())))

# OK, that doesn't work, so let's instead set the weights as new parameters (I think that's how it's done in accelerate)
for key, value in fp16_dict.items():
    param_cls = type(layer_meta._parameters[key])
    new_value = param_cls(value, **layer_meta._parameters[key].__dict__)
    layer_meta._parameters[key] = new_value
print("Fp16 into meta =>", next(iter(layer_meta.parameters())).dtype) # => fp16

layer_meta = nn.Linear(100, 100, device="meta")
for key, value in fp32_dict.items():
    param_cls = type(layer_meta._parameters[key])
    new_value = param_cls(value, **layer_meta._parameters[key].__dict__)
    layer_meta._parameters[key] = new_value
print("Fp32 into meta =>", next(iter(layer_meta.parameters())).dtype) # => fp32

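If you execute this code, you can see that for normally initialized modules, the loaded weights are always cast to the dtype of the randomly initialized weights, which makes sense. With device="meta", however, load_state_dict leaves the weights on the meta device, and when the parameters are replaced directly, the module simply takes on the dtype of the loaded weights. Since there really are no randomly initialized weights when using device="meta", what API / behavior should be adopted here?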

Some possible solutions that don’t seem to work currently:

  • 1.) Actually allow loading a state_dict into a module that has device="meta" weights. E.g., the snippet layer_meta.load_state_dict(fp32_dict) is currently a no-op; is the plan to change this? If so, should the dtype of the “meta” weights also define the dtype of the loaded weights? To be more precise, when doing:
layer_meta = nn.Linear(100, 100, device="meta")

the module layer_meta actually has a dtype, which defaults to fp32 just like for “normal” modules. The problem, however, is that load_state_dict doesn’t work on this module, so libraries currently cannot adopt this approach. Also, would it even make sense to do this? If a module doesn’t initialize random weights, are things like module.dtype meaningful? The dtype really is an attribute of the attached weights, and by definition there are no weights here, so does it make sense to actually use the dtype of a “meta” tensor?

  • 2.) The loaded weights define what dtype the model should have. If I instantiate a module with device="meta", which allocates no weights and no memory, and then load fp16 weights into it, should the module simply take on this dtype? This seems sensible to me, but the big drawback is that it would be somewhat inconsistent with the existing API: people would see different results depending on whether the model was initialized with device="meta" or not.
  • 3.) Do nothing, since device="meta" is still pretty new and doesn’t really have official docs yet (although, as said before, it’s already widely used by the community).
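To make the difference between options 1 and 2 tangible, here is a minimal sketch with existing APIs. Option 1 is approximated with Module.to_empty, which allocates uninitialized storage using the meta parameters' dtype before a regular load_state_dict; note that this still allocates memory, so it only saves the initialization time, not the allocation. Option 2 uses a hypothetical helper (load_state_dict_keep_checkpoint_dtype, not a PyTorch API) that mirrors the parameter-replacement trick from the snippet above. The fp16 state dict is a stand-in for fp16_dict.

```python
import torch
from torch import nn

# Stand-in for fp16_dict from the snippet above.
fp16_dict = nn.Linear(100, 100).half().state_dict()

# Option 1 (sketch): the meta module's dtype wins.
# to_empty() replaces each meta tensor with uninitialized CPU storage of the same dtype (fp32),
# then load_state_dict copies the checkpoint into it, casting fp16 -> fp32.
layer_opt1 = nn.Linear(100, 100, device="meta").to_empty(device="cpu")
layer_opt1.load_state_dict(fp16_dict)
print("option 1 =>", layer_opt1.weight.dtype)

# Option 2 (sketch): the checkpoint's dtype wins.
# Hypothetical helper: swap each parameter for a new Parameter built from the loaded tensor.
# Only handles parameters owned directly by `module` (no nested submodules).
def load_state_dict_keep_checkpoint_dtype(module, state_dict):
    for name, value in state_dict.items():
        old = module._parameters[name]
        module._parameters[name] = nn.Parameter(value, requires_grad=old.requires_grad)
    return module

layer_opt2 = load_state_dict_keep_checkpoint_dtype(nn.Linear(100, 100, device="meta"), fp16_dict)
print("option 2 =>", layer_opt2.weight.dtype)
```

The first print should show torch.float32 and the second torch.float16, i.e. the same checkpoint ends up in different dtypes depending on which rule is chosen.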

To throw in another pointer that might be helpful: in JAX/Flax, the weights are not attached to the module and are only passed to the model when it is called. Therefore, only the loaded weights decide the dtype of the model, and the problematic question of whether the “random model weights” should be upcast to the “loaded model weights” or the other way around never arises.
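The Flax-style pattern can be mimicked in plain PyTorch with a stateless function that receives its parameters at call time. This is only an illustration: linear_apply is a hypothetical helper (not Flax's or PyTorch's API), and float64/float32 are used instead of fp16 simply to avoid depending on CPU half-precision matmul support.

```python
import torch

# Hypothetical, purely functional linear layer: it owns no weights,
# the parameters are passed in on every call (as with Flax's apply(params, x)).
def linear_apply(params, x):
    return x @ params["weight"].T + params["bias"]

# Whatever dtype the loaded parameters have is the dtype the model computes in.
params_fp64 = {
    "weight": torch.randn(100, 100, dtype=torch.float64),
    "bias": torch.randn(100, dtype=torch.float64),
}
x = torch.randn(8, 100, dtype=torch.float64)
out_fp64 = linear_apply(params_fp64, x)

# Same "model", different checkpoint dtype: nothing needs to be up- or downcast.
params_fp32 = {k: v.float() for k, v in params_fp64.items()}
out_fp32 = linear_apply(params_fp32, x.float())

print(out_fp64.dtype, out_fp32.dtype)
```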