Hello,
I am experimenting with quantization of LSTM weights. The way I am doing it is as follows:

• Get the state_dict
• Quantize its values (tensors)

During the quantization step I change the dtype to torch.uint8, and the change is reflected in the printed state_dict. However, the data type is back to float32 when I load it into the module again.

Here is my code:

```python
import torch
import torch.nn as nn


class LSTMQuantizer:
    def __init__(self, scale=0.1, zero_point=10):
        self.scale = scale
        self.zero_point = zero_point
        self.dtype = torch.qint8

    def quantize_lstm(self, lstm_layer: nn.LSTM):
        """Quantize every tensor in the layer's state_dict."""
        if isinstance(lstm_layer, nn.LSTM):
            # dictionary of weights
            d = lstm_layer.state_dict()
            for key in d.keys():
                d[key] = self._quantize(d[key])
            print(d)
        else:
            raise TypeError("quantize_lstm expects an nn.LSTM")

    def dequantize_lstm(self, lstm_layer: nn.LSTM):
        """Dequantize every tensor in the layer's state_dict."""
        if isinstance(lstm_layer, nn.LSTM):
            # dictionary of weights
            d = lstm_layer.state_dict()
            for key in d.keys():
                d[key] = self._dequantize(d[key])
        else:
            raise TypeError("dequantize_lstm expects an nn.LSTM")

    def _quantize(self, t: torch.Tensor):
        """Apply affine quantization, convert the data type to uint8
        and return the tensor."""
        t.apply_(lambda i: round((i / self.scale) + self.zero_point))
        t = t.to(torch.uint8)
        return t

    def _dequantize(self, t: torch.Tensor):
        """Convert the data type to float32, apply dequantization
        and return the tensor."""
        t = t.to(torch.float)
        t.apply_(lambda i: (i - self.zero_point) * self.scale)
        return t


if __name__ == "__main__":
    m = nn.LSTM(1, 1)
    # print("Before")
    # print(m.state_dict())
    q = LSTMQuantizer(0.23, 11)

    q.quantize_lstm(m)
    print("Quantization")
    print(m.state_dict())

    # q.dequantize_lstm(m)
    # print("Dequantization")
    # print(m.state_dict())
```

Kindly suggest whether this approach is possible or not.

There are several things that I think are off here:

1. I am not 100% sure the torch kernels support uint8 operations outside the `QuantizedCPU` dispatch. In your code, you are quantizing the values manually and storing them with the `torch.uint8` dtype. That means there would have to be a `CPU` dispatch for the `uint8` dtype – I'm not sure that's true.
2. You are losing the information about the scale and zero_point after you quantize the values: how would a kernel know which scales and zero_points to use?
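One way around the second issue is PyTorch's own quantized-tensor representation: `torch.quantize_per_tensor` stores the scale and zero_point on the tensor itself, so the quantization parameters travel with the data. A minimal sketch (the scale/zero_point values here are arbitrary, just mirroring your affine scheme):

```python
import torch

t = torch.tensor([0.5, -0.3, 1.2])

# quantize with an explicit scale and zero_point; both are kept
# as metadata on the resulting quantized tensor
qt = torch.quantize_per_tensor(t, scale=0.1, zero_point=10,
                               dtype=torch.quint8)

print(qt.q_scale())       # the stored scale (0.1)
print(qt.q_zero_point())  # the stored zero_point (10)
print(qt.int_repr())      # the underlying uint8 values
print(qt.dequantize())    # float32 recovery, accurate to ~scale/2
```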

Instead, you could use either of the following:

1. QAT, if you want to use `FakeQuantize` – this module lets you use FP kernels while simulating the quantized operation
2. PTQ, if you want to quantize the modules completely.
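To make option 1 concrete: the functional form `torch.fake_quantize_per_tensor_affine` snaps values onto the integer grid defined by a scale and zero_point but returns an ordinary float32 tensor, so every regular FP kernel keeps working. A small sketch with arbitrary parameters:

```python
import torch

t = torch.tensor([0.5, -0.3, 1.2])

# round/clamp to the uint8 grid, then map straight back to float:
# the values carry quantization error, but the dtype stays float32
fq = torch.fake_quantize_per_tensor_affine(t, scale=0.1, zero_point=10,
                                           quant_min=0, quant_max=255)

print(fq.dtype)  # torch.float32
print(fq)        # values representable on the (0.1, 10) grid
```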

Please, read through this blog post for more details: Practical Quantization in PyTorch | PyTorch
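As a concrete example of option 2, dynamic PTQ already ships for LSTMs: `torch.ao.quantization.quantize_dynamic` replaces the layer with a version that holds int8 weights and quantizes activations on the fly at inference time. A minimal sketch:

```python
import torch
import torch.nn as nn

m = nn.LSTM(1, 1)

# dynamic post-training quantization: weights are converted to int8
# once, activation scales are computed dynamically per batch
qm = torch.ao.quantization.quantize_dynamic(m, {nn.LSTM},
                                            dtype=torch.qint8)

# the quantized module runs a normal forward pass
out, (h, c) = qm(torch.randn(5, 3, 1))
print(out.shape)  # torch.Size([5, 3, 1])
```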