Hello! I am learning more about quantisation and decided to code it up myself. I am interested in simple scalar quantisation of the model weights only. My approach is the following:
- get all the tensors related to the model parameters (let them be p)
- quantise them and save them in int8 (let them be p_quantised)
- add a pre-forward hook that dequantises the int8 to fp32 (p_quantised → p)
- add a (post-)forward hook that replaces the dequantised fp32 with the previous int8 (p → p_quantised); I sketch this below
My main issue is that I am unable to force the model state dict to actually use int8 rather than fp32. So if I do:
sd = model.state_dict()
p = sd['some_parameter']              # p is in fp32
p_q = quantise(sd['some_parameter'])  # p_q is now in int8
sd['some_parameter'] = p_q
model.load_state_dict(sd)
then the model has the values of p_q, but they are still in fp32. Adding .type(torch.int8) or similar did not work. Is there a workaround to this behaviour? What is the suggested path to implement this (just for learning)? Do I need to use the torch.quantization library?
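For completeness, here is a self-contained repro of the behaviour I am describing (using a toy round-and-cast in place of my real quantise helper):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
sd = model.state_dict()
sd['weight'] = torch.round(sd['weight'] * 127).to(torch.int8)  # pretend-quantise
model.load_state_dict(sd)
print(model.weight.dtype)  # torch.float32: the int8 values were cast back into the fp32 parameter
```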
Thanks!