Hello,
I am experimenting with quantization of LSTM weights. The way I am doing it is as follows:

• Get the state_dict
• Quantize its values (tensors)

During the quantization step I change the dtype to torch.uint8, and the change is reflected in the printed state_dict. However, the data type is back to float32 when I load it into the module again.

Here is my code:

```python
import torch
import torch.nn as nn


class LSTMQuantizer:
    def __init__(self, scale=0.1, zero_point=10):
        self.scale = scale
        self.zero_point = zero_point
        self.dtype = torch.qint8

    def quantize_lstm(self, lstm_layer: nn.LSTM):
        """Quantize every tensor in the layer's state_dict."""
        if isinstance(lstm_layer, nn.LSTM):
            # dictionary of weights
            d = lstm_layer.state_dict()
            for key in d.keys():
                d[key] = self._quantize(d[key])
            print(d)
        else:
            raise TypeError("quantize_lstm expects an nn.LSTM")

    def dequantize_lstm(self, lstm_layer: nn.LSTM):
        """Dequantize every tensor in the layer's state_dict."""
        if isinstance(lstm_layer, nn.LSTM):
            # dictionary of weights
            d = lstm_layer.state_dict()
            for key in d.keys():
                d[key] = self._dequantize(d[key])
        else:
            raise TypeError("dequantize_lstm expects an nn.LSTM")

    def _quantize(self, t: torch.Tensor):
        """Apply affine quantization, convert the data type to uint8
        and return the tensor."""
        t.apply_(lambda i: round((i / self.scale) + self.zero_point))
        t = t.to(torch.uint8)
        return t

    def _dequantize(self, t: torch.Tensor):
        """Convert the data type to float32, apply dequantization
        and return the tensor."""
        t = t.to(torch.float)
        t.apply_(lambda i: (i - self.zero_point) * self.scale)
        return t


if __name__ == "__main__":
    m = nn.LSTM(1, 1)
    # print("Before")
    # print(m.state_dict())
    q = LSTMQuantizer(0.23, 11)

    q.quantize_lstm(m)
    print("Quantization")
    print(m.state_dict())

    # q.dequantize_lstm(m)
    # print("Dequantization")
    # print(m.state_dict())
```

Kindly suggest whether this approach is possible or not.

There are several things that I think are off here:

1. I am not 100% sure the torch kernels support uint8 operations outside the `QuantizedCPU` dispatch. In your code, you are quantizing the values manually and storing them with the `torch.uint8` dtype. That means there would have to be a `CPU` dispatch for the `uint8` dtype – I'm not sure that's true.
2. You are losing the information about the scale and zero_point after you quantize the values: how would a kernel know which scales and zero_points to use?
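One way around the second issue is PyTorch's own quantized-tensor representation: `torch.quantize_per_tensor` stores the scale and zero_point on the tensor itself, so the quantization parameters travel with the data. A minimal sketch (the scale/zero_point values here are arbitrary, just mirroring your affine scheme):

```python
import torch

t = torch.tensor([0.5, -0.3, 1.2])

# quantize with an explicit scale and zero_point; both are kept
# as metadata on the resulting quantized tensor
qt = torch.quantize_per_tensor(t, scale=0.1, zero_point=10,
                               dtype=torch.quint8)

print(qt.q_scale())       # the stored scale (0.1)
print(qt.q_zero_point())  # the stored zero_point (10)
print(qt.int_repr())      # the underlying uint8 values
print(qt.dequantize())    # float32 recovery, accurate to ~scale/2
```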

Instead, you could use either of the following:

1. QAT, if you want to use `FakeQuantize` – this module lets you use FP kernels while simulating the quantized operation
2. PTQ, if you want to quantize the modules completely.
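To make option 1 concrete: the functional form `torch.fake_quantize_per_tensor_affine` snaps values onto the integer grid defined by a scale and zero_point but returns an ordinary float32 tensor, so every regular FP kernel keeps working. A small sketch with arbitrary parameters:

```python
import torch

t = torch.tensor([0.5, -0.3, 1.2])

# round/clamp to the uint8 grid, then map straight back to float:
# the values carry quantization error, but the dtype stays float32
fq = torch.fake_quantize_per_tensor_affine(t, scale=0.1, zero_point=10,
                                           quant_min=0, quant_max=255)

print(fq.dtype)  # torch.float32
print(fq)        # values representable on the (0.1, 10) grid
```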

Please, read through this blog post for more details: Practical Quantization in PyTorch | PyTorch
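As a concrete example of option 2, dynamic PTQ already ships for LSTMs: `torch.ao.quantization.quantize_dynamic` replaces the layer with a version that holds int8 weights and quantizes activations on the fly at inference time. A minimal sketch:

```python
import torch
import torch.nn as nn

m = nn.LSTM(1, 1)

# dynamic post-training quantization: weights are converted to int8
# once, activation scales are computed dynamically per batch
qm = torch.ao.quantization.quantize_dynamic(m, {nn.LSTM},
                                            dtype=torch.qint8)

# the quantized module runs a normal forward pass
out, (h, c) = qm(torch.randn(5, 3, 1))
print(out.shape)  # torch.Size([5, 3, 1])
```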