load_state_dict drops the data type

Hello,
I am experimenting with quantization of LSTM weights. Here is how I am doing it:

  • Get the state_dict
  • Quantize its values (tensors)
  • load it back

During the quantization step I change the dtype to torch.uint8, and the change shows up in the state_dict. However, the data type is back to float32 after I load the state_dict into the layer.

Here is my code:

```python
import torch.nn as nn
import torch


class LSTMQuantizer:
    def __init__(self, scale=0.1, zero_point=10):
        self.scale = scale
        self.zero_point = zero_point
        self.dtype = torch.qint8

    def quantize_lstm(self, lstm_layer: nn.LSTM):
        """
        Quantize every tensor in the layer's state_dict and load it back.
        """
        if isinstance(lstm_layer, nn.LSTM):
            # dictionary of weights
            d = lstm_layer.state_dict()
            for key in d.keys():
                d[key] = self._quantize(d[key])
            print(d)
            # d._metadata[""]["version"] = 2
            lstm_layer.load_state_dict(d)
        else:
            print("TODO: raise exception in quantize_lstm")

    def dequantize_lstm(self, lstm_layer: nn.LSTM):
        """
        Dequantize every tensor in the layer's state_dict and load it back.
        """
        if isinstance(lstm_layer, nn.LSTM):
            # dictionary of weights
            d = lstm_layer.state_dict()
            for key in d.keys():
                d[key] = self._dequantize(d[key])
            lstm_layer.load_state_dict(d)
        else:
            print("TODO: raise exception in dequantize_lstm")

    def _quantize(self, t: torch.Tensor):
        """
        Apply quantization, convert data type to uint8 and return
        the tensor
        """
        t.apply_(lambda i: round((i / self.scale) + self.zero_point))
        # t.type(torch.uint8)
        t = t.to(torch.uint8)
        # print(t)
        return t

    def _dequantize(self, t: torch.Tensor):
        """
        Convert the data type to float32, apply dequantization and return
        the tensor
        """
        t = t.to(torch.float)
        t.apply_(lambda i: (i - self.zero_point) * self.scale)
        # print(t)
        return t


if __name__ == "__main__":
    m = nn.LSTM(1, 1)
    # print("Before")
    # print(m.state_dict())
    q = LSTMQuantizer(0.23, 11)

    q.quantize_lstm(m)
    print("Quantization")
    print(m.state_dict())

    # q.dequantize_lstm(m)
    # print("Dequantization")
    # print(m.state_dict())
```

Could you tell me whether this is possible?

There are several things that I think are off here:

  1. I am not 100% sure that the torch kernels support uint8 operations outside the QuantizedCPU dispatch. In your code, you quantize the values manually and store them with the torch.uint8 dtype. That would require a CPU dispatch for the uint8 dtype, and I'm not sure one exists.
  2. You lose the scale and zero_point information once you quantize the values: how would a kernel know which scales and zero_points to use?
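Behind the title question there is a simpler mechanism: `load_state_dict` does not replace the module's parameters; it copies each incoming tensor into the existing float32 `Parameter` with an in-place `copy_`, and `copy_` casts to the destination dtype. A PyTorch quantized tensor, by contrast, carries its scale and zero_point with it. The sketch below checks both. It assumes `torch.quint8` (unsigned, range [0, 255]) as the closest built-in match to the manual uint8 scheme, which uses the same affine mapping q = round(x / scale + zero_point); the fill value 7 and the random 4x4 weight are arbitrary, and scale 0.23 / zero_point 11 come from the question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(1, 1)

# 1) load_state_dict copies into the existing float32 Parameters, so the dtype is not kept.
#    Every entry is filled with the arbitrary code 7 to stay inside the uint8 range.
uint8_sd = {k: torch.full_like(v, 7, dtype=torch.uint8) for k, v in lstm.state_dict().items()}
lstm.load_state_dict(uint8_sd)
loaded = lstm.state_dict()
print({k: v.dtype for k, v in loaded.items()})  # every entry: torch.float32
print(loaded["weight_ih_l0"])                   # values 7.0, cast back to float

# 2) A quantized tensor keeps scale and zero_point next to the raw uint8 codes.
w = torch.empty(4, 4).uniform_(-1, 1)           # stand-in weight, inside the representable range
qw = torch.quantize_per_tensor(w, scale=0.23, zero_point=11, dtype=torch.quint8)
print(qw.q_scale(), qw.q_zero_point())          # 0.23 11
codes = qw.int_repr()                           # raw storage as a plain uint8 tensor
print(codes.dtype)                              # torch.uint8
print((qw.dequantize() - w).abs().max())        # rounding error, at most scale / 2 here
```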

Instead, you could use either of the following:

  1. QAT (quantization-aware training), if you want to use `FakeQuantize`: this module lets you keep using FP kernels while simulating the quantized operation.
  2. PTQ (post-training quantization), if you want to quantize the modules completely.
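To see what "simulate the quantized operation" means in practice, here is a minimal sketch using `torch.fake_quantize_per_tensor_affine`, the per-tensor functional op that `FakeQuantize` relies on. It quantizes to the integer grid and immediately dequantizes, so the result is still float32, and its gradient is a straight-through estimator. The random 4x4 weight and the [0, 255] range are assumptions for illustration; scale and zero_point are the question's values.

```python
import torch

torch.manual_seed(0)
# Stand-in for one LSTM weight matrix, kept inside the representable range.
w = torch.empty(4, 4).uniform_(-1, 1).requires_grad_()

scale, zero_point = 0.23, 11  # values from the question; quant range [0, 255]
fq = torch.fake_quantize_per_tensor_affine(w, scale, zero_point, 0, 255)

print(fq.dtype)  # torch.float32: still an ordinary float tensor for FP kernels

# Every value of fq lies on the quantization grid (q - zero_point) * scale.
codes = torch.round(fq / scale + zero_point)
print(codes.min().item(), codes.max().item())

# Straight-through estimator: gradients pass through unchanged for in-range values,
# which is what makes training with simulated quantization (QAT) possible.
fq.sum().backward()
print(w.grad)  # all ones for in-range inputs
```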

Please read through this blog post for more details: Practical Quantization in PyTorch.
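For the PTQ route, one eager-mode flavour that supports `nn.LSTM` is dynamic quantization via `torch.ao.quantization.quantize_dynamic`: the LSTM weights are stored as qint8 together with their quantization parameters, and activations are quantized on the fly. A minimal sketch follows; `TinyModel` is a hypothetical wrapper, used because the conversion swaps child modules rather than the root module, and the sizes 4 and 8 are arbitrary.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic


class TinyModel(nn.Module):
    """Hypothetical wrapper: quantize_dynamic replaces submodules, so the LSTM is a child."""

    def __init__(self, input_size: int = 4, hidden_size: int = 8):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.lstm(x)
        return out


torch.manual_seed(0)
model = TinyModel().eval()

# Post-training dynamic quantization of every nn.LSTM child; the float model is
# deep-copied (inplace=False by default), so `model` itself stays unchanged.
qmodel = quantize_dynamic(model, {nn.LSTM}, dtype=torch.qint8)
print(qmodel.lstm)  # a quantized dynamic LSTM has replaced nn.LSTM

x = torch.randn(5, 3, 4)  # (seq_len, batch, input_size)
with torch.no_grad():
    ref = model(x)
    out = qmodel(x)
print(out.shape, (out - ref).abs().max())  # same shape, small quantization error
```

Unlike the manual uint8 approach, the quantized module keeps the scale and zero_point with the weights and dispatches to kernels written for them.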