I have a model consisting of two parts: the first is a float32 encoder, and the second is a quantized LLM.
If I load the LLM in bfloat16, I can call encoder(x).bfloat16()
and feed the result into the LLM. But when the LLM is loaded in 8-bit, I cannot find a corresponding operation to convert the encoder's output.
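To make the setup concrete, here is a minimal sketch of what I mean; `encoder` is a hypothetical stand-in for my actual float32 encoder module:

```python
import torch

# Hypothetical float32 encoder standing in for the real one.
encoder = torch.nn.Linear(16, 32)
x = torch.randn(2, 16)

# This works when the LLM weights are in bfloat16: cast the encoder
# output to match the LLM's dtype before feeding it in.
h = encoder(x).bfloat16()
assert h.dtype == torch.bfloat16

# But for an 8-bit quantized LLM there is no analogous tensor cast
# (torch.Tensor has no method that produces the 8-bit quantized
# representation the LLM's layers use internally).
```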