I’m writing my own implementation of the Llama 2 model in PyTorch and I want to load the weights from Hugging Face into my model. The problem is that the state_dicts don’t match, so I wrote a method for modifying the state_dict of the original Llama 2 model and loading it with load_state_dict. My code works, but it takes a few minutes to execute, whereas loading the weights from Hugging Face’s hub only takes about 30 seconds on my machine. Is there a way I can make my code run faster?
def from_llama_state_dict(llama_state_dict, config, pruning_rates=None):
    if isinstance(config, LlamaConfig):
        config = LazyLlamaConfig.from_llama_config(pruning_rates, config)
    elif not isinstance(config, LazyLlamaConfig):
        raise ValueError("Config must be an instance of either LlamaConfig or LazyLlamaConfig.")
    new_state_dict = OrderedDict((modify_key(key), value) for key, value in llama_state_dict.items())
    model = LazyLlamaForCausalLM(config)
    model.load_state_dict(new_state_dict)
    return model


def modify_key(key):
    if "model.layers" in key:
        temp = key.split(".")
        temp.insert(3, "decoder")
        return ".".join(temp)
    else:
        return key
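For context, here is a self-contained sketch of what modify_key does to a couple of parameter names (the example key names follow the usual Hugging Face Llama naming, which I'm assuming matches the checkpoint):

def modify_key(key):
    # Insert "decoder" as the fourth dotted component of every
    # per-layer parameter name; leave all other keys untouched.
    if "model.layers" in key:
        temp = key.split(".")
        temp.insert(3, "decoder")
        return ".".join(temp)
    return key

print(modify_key("model.layers.0.self_attn.q_proj.weight"))
# model.layers.0.decoder.self_attn.q_proj.weight
print(modify_key("model.embed_tokens.weight"))
# model.embed_tokens.weight  (unchanged: no "model.layers" prefix)

Note that this key rewrite itself only touches strings, not tensors, so rebuilding the OrderedDict should be cheap; I suspect the time is going elsewhere.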