Dequantize tensors from int8 to fp16

I have loaded an LLM with Hugging Face Transformers using load_in_8bit=True.
I noticed the entries in the state_dict are structured something like

  1. model.layers.18.self_attn.k_proj.weight
  2. model.layers.18.self_attn.k_proj.SCB
  3. model.layers.18.self_attn.k_proj.weight_format

The SCB and weight_format entries are present only in the quantized model. I think SCB refers to the scale and bias that could help recreate the original tensor? weight_format is just a string that says "row".
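For reference, this is roughly how I am inspecting those entries (the layer name is just one example, and the shape comments are what I observe; the model itself is loaded as shown further down in this thread):

sd = model.state_dict()
prefix = "model.layers.18.self_attn.k_proj."
weight = sd[prefix + "weight"]         # int8 matrix, shape [out_features, in_features]
scb = sd[prefix + "SCB"]               # one value per output row
fmt = sd[prefix + "weight_format"]     # shows up as "row"
print(weight.dtype, weight.shape)
print(scb.dtype, scb.shape)
print(fmt)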

I am not sure about the exact method to dequantize, but I tried the following:
(weight_SCB.unsqueeze(1) * weight)/127
This gives a tensor that is close to the corresponding one in the original model (loaded without load_in_8bit=True), but it is not identical.
I think I am doing something wrong in the dequantization process. It would be great if someone could point me to code or documentation on how to recreate the exact original tensor from these weights (alternatives to Hugging Face work as well).
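For completeness, this is the comparison I am running (the layer name is a placeholder; the second from_pretrained call just loads a non-quantized fp16 copy of the same checkpoint as a reference):

import torch
from transformers import AutoModelForCausalLM

ref = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", torch_dtype=torch.float16)

key = "model.layers.18.self_attn.k_proj.weight"
q_sd = model.state_dict()                       # quantized model, loaded with load_in_8bit=True
weight = q_sd[key]                              # int8
scb = q_sd[key.replace("weight", "SCB")]        # quantization statistics for that layer

dequant = (scb.unsqueeze(1).float() * weight.float()) / 127.0
original = ref.state_dict()[key].float().to(dequant.device)
print((dequant - original).abs().max())         # small, but not zero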

As a follow-up question, I know that for some models there are outlier values that are left unquantized even though the other values in the tensor are quantized. However, I could not find this information in the state_dict. How can I find and handle these values during dequantization?

Hi, is this a Hugging Face quantized model? Can you show all the code you used to quantize or load the model?

@jerryzh168 yes. This is the code I used:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", cache_dir="/path/to/dir", device_map="cuda:0", load_in_8bit=True)
state_dict = model.state_dict()
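In case it is useful, this is how I checked which modules got the extra entries (every quantized linear layer seems to have a matching SCB key):

scb_keys = [k for k in state_dict if k.endswith(".SCB")]
print(len(scb_keys))
print(scb_keys[:3])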

I see, I'm not sure how they are quantizing the model. Maybe you can open an issue in their GitHub repo? GitHub - huggingface/transformers: 🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.