Hello! I am learning more about quantisation and decided to implement it myself. I am interested in simple scalar quantisation of the model weights only. My approach is the following:

- get all the tensors related to the model parameters (call them `p`)
- quantise them and store them in `int8` (call them `p_quantised`)
- add a pre-forward hook that dequantises the `int8` back to `fp32` (`p_quantised → p`)
- add a (post-)forward hook that replaces the dequantised `fp32` with the previous `int8` (`p → p_quantised`)
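To make the steps concrete, here is a minimal sketch of what I have in mind, using symmetric per-tensor quantisation and keeping the `int8` copies in a side dictionary (the scale formula, the dictionary, and the hook layout are just my own choices, and this is inference-only):

```python
import torch
import torch.nn as nn

def quantise(t: torch.Tensor):
    # symmetric per-tensor scalar quantisation to int8
    scale = t.abs().max().clamp(min=1e-8) / 127.0
    q = torch.round(t / scale).clamp(-127, 127).to(torch.int8)
    return q, scale

def dequantise(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float32) * scale

model = nn.Linear(4, 2)

# p -> p_quantised: store the int8 copies (plus scales) outside the module
quantised = {name: quantise(p.data) for name, p in model.named_parameters()}

def pre_hook(module, inputs):
    # p_quantised -> p: rebuild the fp32 weights just before the forward pass
    for name, p in module.named_parameters():
        q, scale = quantised[name]
        p.data = dequantise(q, scale)

def post_hook(module, inputs, output):
    # p -> p_quantised: drop the fp32 copy again after the forward pass
    # (the int8 version lives on in `quantised`)
    for _, p in module.named_parameters():
        p.data = torch.empty(0)

model.register_forward_pre_hook(pre_hook)
model.register_forward_hook(post_hook)

x = torch.randn(3, 4)
y = model(x)  # runs on the dequantised weights
```

Note that between forward passes the fp32 copy is simply dropped rather than replaced in place, since the `int8` tensors live in the side dictionary, which is exactly the part I would like to avoid.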

My main issue is that I am unable to force the model state dict to actually use `int8` rather than `fp32`. So if I do:

```
sd = model.state_dict()
p = sd['some_parameter']             # p is in fp32
p_q = quantise(sd['some_parameter']) # p_q is now in int8
sd['some_parameter'] = p_q
model.load_state_dict(sd)
```

then the model has the values of p_q, but they are stored in `fp32`. Adding `.type(torch.int8)` or similar did not work. Is there a workaround for this behaviour? What is the suggested path to implement this (just for learning)? Do I need to use the `torch.quantization` library?
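For reference, here is a self-contained version of the check above. As far as I understand, `load_state_dict` ends up calling `Parameter.copy_`, which silently casts the incoming tensor back to the parameter's existing dtype (the `.to(torch.int8)` here is only to test the dtype; it truncates the values):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

sd = model.state_dict()
sd['weight'] = sd['weight'].to(torch.int8)  # dtype test only; values are truncated
model.load_state_dict(sd)

# the parameter is still fp32: copy_ cast the int8 values back
print(model.weight.dtype)
```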

Thanks!