How can I save a state dict in int8 format and load it back as float32?


I would like to use quantization to save my model's state dict in int8 so that I can save storage space. I also need to load these int8 states back into a float32 model for inference. How can I do this?

This isn’t a standard flow PyTorch quantization provides, but you could do something like this:

  1. For a tensor, use torch.quantize_per_tensor(x, ...) to convert fp32 -> int8, and x.dequantize() to convert back from int8 to fp32.
  2. Override the _save_to_state_dict and _load_from_state_dict methods on the modules you'd like to do this for with your custom logic: call quantize_per_tensor when saving, and convert back with dequantize when loading.
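A minimal sketch of step 2 for an nn.Linear. Note the assumptions: _save_to_state_dict and _load_from_state_dict are private hooks whose signatures can change between PyTorch versions, and the symmetric scale = max|w| / 127 used here is just one simple choice of quantization parameters (observers, discussed below, are the more principled route):

```python
import torch
import torch.nn as nn


class Int8StateLinear(nn.Linear):
    """Linear layer that stores its weight as qint8 in the state dict."""

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        key = prefix + "weight"
        w = destination[key]
        # Symmetric per-tensor quantization: scale chosen so max|w| maps to 127.
        scale = max(w.abs().max().item() / 127.0, 1e-8)
        destination[key] = torch.quantize_per_tensor(
            w, scale=scale, zero_point=0, dtype=torch.qint8
        )

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        key = prefix + "weight"
        if key in state_dict and state_dict[key].is_quantized:
            # int8 -> fp32 before the default loading logic copies it in.
            state_dict[key] = state_dict[key].dequantize()
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
```

With this, `model.state_dict()` holds a qint8 weight (roughly a quarter of the fp32 storage), and `load_state_dict` transparently restores a float32 weight.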


Thanks for replying! There are scale and zero_point arguments in torch.quantize_per_tensor. When I convert a float32 parameter to int8, how should I determine these arguments? Is there a built-in way, or should I compute the max/min values of each tensor myself?

You can use the observer classes in torch.quantization; this is also what the official quantization flows use to compute scale and zero_point. For example,

obs = torch.quantization.MinMaxObserver(...)
obs(x)  # run data through the observer
# get scale and zero_point from the data the observer has seen
scale, zp = obs.calculate_qparams()
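Putting the observer and quantize/dequantize round trip together, a runnable sketch (using the observer's defaults, i.e. per-tensor affine quint8; in newer releases the observers also live under torch.ao.quantization):

```python
import torch
from torch.quantization import MinMaxObserver

x = torch.randn(100, 10)  # stand-in for a fp32 parameter tensor

# The observer tracks the min/max of everything passed through it.
obs = MinMaxObserver()
obs(x)
scale, zp = obs.calculate_qparams()

# Quantize with the observed parameters, then reconstruct fp32.
xq = torch.quantize_per_tensor(x, float(scale), int(zp), torch.quint8)
x_rec = xq.dequantize()

# The reconstruction error is bounded by roughly half the scale.
print(float((x - x_rec).abs().max()))
```

The default MinMaxObserver computes an affine mapping covering the observed range, so each value round-trips to within about scale/2.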

One other thing to consider is using fp16; that way you can just do x.half() without worrying about quantization params.