I have a model consisting of two parts: the first is a float32 encoder, and the second is a quantized LLM.
If I load the LLM in bfloat16, I can do encoder(x).bfloat16()
and feed the result into the LLM. But for the LLM in 8-bit, I cannot find a corresponding operation to convert the encoder's output.
If you just want to feed an int8 tensor, you can probably still do encoder(x).char(),
I guess? But that is not really quantization, since int8 quantization requires a scale/zero_point as well.
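For comparison, here is a minimal sketch in plain PyTorch of what the scale/zero_point version looks like (the absmax-based scale is just an illustrative choice):

```python
import torch

x = torch.randn(2, 4)            # stand-in for the float32 encoder output

raw = x.char()                   # plain cast to int8: fractions truncated, no scale kept

scale = x.abs().max().item() / 127.0
q = torch.quantize_per_tensor(x, scale=scale, zero_point=0, dtype=torch.qint8)
print(q.int_repr())              # the stored int8 values
print(q.dequantize())            # floats reconstructed from the scale, close to x
```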
Hello,
You have a few options here:

1. Manual Quantization: You can quantize the encoder output to 8-bit integers yourself, using the same method (and the same scale/zero_point bookkeeping) the LLM expects during inference; see the first sketch after this list.
2. Library Handling: Libraries like bitsandbytes or transformers from Hugging Face often handle these operations under the hood, so the 8-bit layers accept floating-point activations and quantize them internally; see the second sketch after this list.
3. Fake Quantization: Another approach is to simulate the 8-bit range without converting the data type:

   ```python
   import torch

   def fake_quantize(tensor, min_val=-1.0, max_val=1.0, levels=256):
       # clamp to the representable range, then snap each value to one of
       # `levels` evenly spaced steps between min_val and max_val
       step = (max_val - min_val) / (levels - 1)
       tensor = tensor.clamp(min_val, max_val)
       return torch.round((tensor - min_val) / step) * step + min_val
   ```

   This approach keeps the data as float32 but simulates 8-bit precision.
4. LLM Library-Specific Hooks: Check whether the LLM's library offers an input transformation or pre-processing API. Many 8-bit models have specific hooks for handling data types, especially in PyTorch, TensorFlow, or ONNX-based models; the last sketch below shows a PyTorch forward pre-hook doing this kind of cast.
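For the manual route (option 1), a minimal sketch of symmetric per-tensor int8 quantization in plain PyTorch; the absmax-based scale is one common choice, but the exact scheme your 8-bit LLM uses may differ:

```python
import torch

def quantize_int8(x):
    # symmetric per-tensor quantization: int8 values plus one float scale
    scale = x.abs().max().clamp(min=1e-8) / 127.0
    q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return q, scale

def dequantize_int8(q, scale):
    return q.to(torch.float32) * scale

emb = torch.randn(1, 16, 4096)          # stand-in for encoder(x)
q, scale = quantize_int8(emb)
approx = dequantize_int8(q, scale)      # close to emb, up to quantization error
```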
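For option 2, the practical upshot with transformers + bitsandbytes is that you usually don't have to convert anything yourself: with load_in_8bit, the quantized linear layers take float16 activations and quantize them on the fly, so you can pass the encoder output directly, e.g. via inputs_embeds. A sketch, where the checkpoint name is just an example and the assumption is that your encoder's output matches the LLM's hidden size:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# example checkpoint; requires a CUDA GPU with bitsandbytes installed
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-1.3b",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)

# stand-in for encoder(x); must match the LLM's hidden size (2048 for opt-1.3b)
embeds = torch.randn(1, 10, 2048, dtype=torch.float16, device=model.device)

# no manual int8 conversion needed: the 8-bit layers quantize activations themselves
out = model(inputs_embeds=embeds)
```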
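And for option 4, in PyTorch specifically you can register a forward pre-hook that casts whatever reaches the model, so a float32 encoder output gets converted automatically. A toy sketch (with_kwargs needs PyTorch 2.0+; the two nn.Linear modules are only stand-ins for your encoder and the quantized LLM):

```python
import torch
import torch.nn as nn

def make_cast_hook(dtype):
    # forward pre-hook: cast any float32 tensor argument to `dtype`
    # before it reaches the wrapped module
    def hook(module, args, kwargs):
        cast = lambda a: a.to(dtype) if torch.is_tensor(a) and a.dtype == torch.float32 else a
        return tuple(cast(a) for a in args), {k: cast(v) for k, v in kwargs.items()}
    return hook

# toy stand-ins: a float32 encoder and a lower-precision layer playing the LLM's role
# (bfloat16 so the toy runs on CPU; an 8-bit bitsandbytes model would typically want float16)
encoder = nn.Linear(8, 16)
llm = nn.Linear(16, 16).to(torch.bfloat16)

handle = llm.register_forward_pre_hook(make_cast_hook(torch.bfloat16), with_kwargs=True)

x = torch.randn(2, 8)
out = llm(encoder(x))   # the float32 encoder output is cast on the way in
handle.remove()         # detach the hook when it is no longer needed
```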
Hope that helps!