Simple quantisation reproduction - how to convert state dict to int8

Hello! I am learning more about quantisation and decided to implement it myself. I am interested in simple scalar quantisation of the model weights only. My approach is the following:

  • get all the tensors related to model parameters (let them be p)
  • quantise them and save them in int8 (let them be p_quantised)
  • add a pre-forward hook that dequantises the int8 to fp32 (p_quantised → p)
  • add a (post-)forward hook that replaces the dequantised fp32 with the previous int8 (p → p_quantised); a rough sketch is below
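
Roughly what I have in mind, as a minimal sketch (per-tensor absmax scaling; the name attach_weight_quantisation is just a placeholder I made up, and device handling is ignored):

import torch
import torch.nn as nn

def quantise(t):
    # per-tensor absmax scalar quantisation to int8
    scale = t.abs().max().clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(t / scale), -128, 127).to(torch.int8)
    return q, scale

def dequantise(q, scale):
    return q.to(torch.float32) * scale

def attach_weight_quantisation(model):
    # keep the int8 copies (p_quantised) on the side
    quantised = {name: quantise(p.data) for name, p in model.named_parameters()}

    def pre_hook(module, inputs):
        # p_quantised -> p: restore fp32 weights before the forward pass
        for name, p in module.named_parameters():
            q, scale = quantised[name]
            p.data = dequantise(q, scale)

    def post_hook(module, inputs, output):
        # p -> p_quantised: drop the fp32 copy, the int8 copy stays in `quantised`
        for name, p in module.named_parameters():
            p.data = torch.empty(0)

    model.register_forward_pre_hook(pre_hook)
    model.register_forward_hook(post_hook)
    post_hook(model, None, None)  # free the fp32 weights right away
    return quantised

Note that the int8 copies live in a side dict here, so they never show up in model.state_dict(), which brings me to the actual problem.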

My main issue is that I am unable to force the model state dict to actually store int8 rather than fp32. That is, if I do:

sd = model.state_dict()
p = sd['some_parameter']              # p is in fp32
p_q = quantise(sd['some_parameter'])  # p_q is now in int8
sd['some_parameter'] = p_q
model.load_state_dict(sd)

then the model has the values of p_q, but they are in fp32. Adding .type(torch.int8) or similar did not work. Is there a workaround to this behaviour? What is the suggested way to implement this (just for learning)? Do I need to use the torch.quantization library?

Thanks!

yeah, you’ll have to change model.weight itself from float32 to an integer dtype as well for this, I think.
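
load_state_dict copies the values into the existing fp32 parameter, so the parameter’s dtype wins. one way around it is to replace the parameter object itself with an int8 tensor, e.g. (rough sketch, not an official flow):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)

with torch.no_grad():
    w = model.weight
    scale = w.abs().max() / 127.0
    w_q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)

# replace the Parameter itself; an int8 Parameter is allowed
# as long as requires_grad=False (only float tensors can require grad)
model.weight = nn.Parameter(w_q, requires_grad=False)

print(model.state_dict()['weight'].dtype)  # torch.int8
# forward() will fail now until you dequantise in a pre-forward hook

you could also delete the parameter and register_buffer the int8 tensor instead; buffers show up in the state dict too.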

not sure what your goal is, but you can also start by using our flows: Quantization — PyTorch main documentation and understand how they work
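
for example, dynamic quantisation stores the Linear weights as int8 with a single call (rough example):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# weights are stored as qint8, activations are quantised on the fly at runtime
q_model = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

print(q_model)
print(q_model(torch.randn(1, 16)).shape)  # runs like the original model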