Loading weights with load_state_dict takes a very long time

I’m writing my own implementation of the Llama 2 model in PyTorch, and I want to load the weights from Hugging Face into my model. The problem is that the state_dicts don’t match, so I wrote a method that modifies the keys of the original Llama 2 state_dict and then loads it with load_state_dict. My code works, but it takes a few minutes to execute, whereas loading the weights from the Hugging Face Hub only takes about 30 seconds on my machine. Is there a way to make my code run faster?

from collections import OrderedDict

from transformers import LlamaConfig


def from_llama_state_dict(llama_state_dict, config, pruning_rates=None):
    if isinstance(config, LlamaConfig):
        config = LazyLlamaConfig.from_llama_config(pruning_rates, config)
    elif not isinstance(config, LazyLlamaConfig):
        raise ValueError("Config must be an instance of either LlamaConfig or LazyLlamaConfig.")

    # Rename every key so it matches my model's module hierarchy.
    new_state_dict = OrderedDict((modify_key(key), value) for key, value in llama_state_dict.items())

    model = LazyLlamaForCausalLM(config)
    model.load_state_dict(new_state_dict)

    return model


def modify_key(key):
    # Insert "decoder" after the layer index, e.g.
    # "model.layers.0.self_attn..." -> "model.layers.0.decoder.self_attn..."
    if "model.layers" in key:
        temp = key.split(".")
        temp.insert(3, "decoder")
        return ".".join(temp)
    else:
        return key
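For reference, here is a self-contained sketch of what the key remapping does on its own (the parameter name used below is just an illustrative example, not taken from a real checkpoint):

```python
def modify_key(key):
    # Insert "decoder" after the layer index in parameter names.
    if "model.layers" in key:
        parts = key.split(".")
        parts.insert(3, "decoder")
        return ".".join(parts)
    return key

# Example of the renaming:
print(modify_key("model.layers.0.self_attn.q_proj.weight"))
# model.layers.0.decoder.self_attn.q_proj.weight

# Keys outside model.layers are left untouched:
print(modify_key("model.embed_tokens.weight"))
# model.embed_tokens.weight
```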