I’ve come up with the following code. As suggested, I perform the matrix multiplication with nn.Linear.
What do you think?
import torch
import torch.nn as nn
class BatchedMatMul(nn.Module):
    """Per-sample matrix product implemented with a single ``nn.Linear``.

    For every batch index ``b`` the weight of the wrapped linear layer is
    overwritten with ``input1[b]`` and the layer is then applied to
    ``input2[b]``.  Because ``nn.Linear`` evaluates ``x @ W.T``, the value
    produced for each sample is ``input2[b] @ input1[b].T`` (note the
    transpose and the operand order).

    The module works both as a float model and after
    ``torch.quantization.convert``; in the converted model ``self.linear`` is
    a quantized linear layer whose ``weight`` is a method, which is how
    ``forward`` tells the two cases apart.

    Args:
        in_features: Input size of the wrapped linear layer (default 3).
        out_features: Output size of the wrapped linear layer (default 3).
    """

    def __init__(self, in_features=3, out_features=3):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, input1, input2):
        """Apply the linear layer sample by sample and stack the results.

        Args:
            input1: Batch of matrices used as layer weights; each
                ``input1[b]`` is expected to match the layer's weight shape
                ``(out_features, in_features)``.
            input2: Batch of layer inputs; each ``input2[b]`` is expected to
                have ``in_features`` as its last dimension.

        Returns:
            The dequantized stack of per-sample outputs, one entry per
            element of ``input1``'s leading dimension.
        """
        outputs = []
        # The loop index is deliberately not called ``b``: that name is also
        # the bias keyword of ``set_weight_bias`` below.
        for idx in range(input1.shape[0]):
            print(f"Linear's type: {type(self.linear)}")
            print(f"Linear's weigth type: {type(self.linear.weight)}")
            if isinstance(self.linear.weight, nn.Parameter):
                # Float (or prepared) model: the weight is a leaf parameter,
                # so gradient tracking is switched off before the in-place
                # overwrite.
                self.linear.weight.requires_grad = False
                self.linear.weight.copy_(self.quant(input1[idx]))
            else:
                # Converted model: re-quantize the new weight with the
                # per-channel parameters of the weight currently stored in
                # the layer.  The axis is taken from that weight as well, so
                # scales/zero-points are applied along the dimension they
                # were computed for instead of a hard-coded one.
                current = self.linear.weight()
                new_weight = torch.quantize_per_channel(
                    input1[idx],
                    current.q_per_channel_scales(),
                    current.q_per_channel_zero_points(),
                    current.q_per_channel_axis(),
                    torch.qint8,
                )
                self.linear.set_weight_bias(new_weight, b=None)
            outputs.append(self.linear(self.quant(input2[idx])))
        return self.dequant(torch.stack(outputs))
# --- Build the module and put it in inference mode -------------------------
print("Cronstruct model...")
matmul = BatchedMatMul()
print("Cronstruct model... [OK]")
matmul.eval()

# --- Float reference run ----------------------------------------------------
print("Running FP32 inference...")
inp = torch.ones(2, 3, 3)  # batch of two all-ones 3x3 matrices
doubled = 2 * inp          # first operand, reused by every call below
y = matmul(doubled, inp)
print("FP32 output...")
print(y)
print("Running FP32 inference... [OK]")

# --- Attach qconfig, calibrate with one forward pass, convert ---------------
print("Quantizing...")
matmul.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
matmul_prepared = torch.quantization.prepare(matmul)
matmul_prepared(doubled, inp)  # calibration pass; the result is not needed
model_int8 = torch.quantization.convert(matmul_prepared)
print("Quantizing... [OK]")

# --- Run the converted model on the same inputs -----------------------------
print("Running INT8 inference...")
y = model_int8.forward(doubled, inp)
print("Int8 Output")
print(y)
print("Running INT8 inference..[OK]")
Output:
Cronstruct model...
Cronstruct model... [OK]
Running FP32 inference...
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
FP32 output...
tensor([[[6., 6., 6.],
[6., 6., 6.],
[6., 6., 6.]],
[[6., 6., 6.],
[6., 6., 6.],
[6., 6., 6.]]])
Running FP32 inference... [OK]
Quantizing...
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
Quantizing... [OK]
Running INT8 inference...
Linear's type: <class 'torch.nn.quantized.modules.linear.Linear'>
Linear's weigth type: <class 'method'>
Linear's type: <class 'torch.nn.quantized.modules.linear.Linear'>
Linear's weigth type: <class 'method'>
Int8 Output
tensor([[[5.9695, 5.9695, 5.9695],
[5.9695, 5.9695, 5.9695],
[5.9695, 5.9695, 5.9695]],
[[5.9695, 5.9695, 5.9695],
[5.9695, 5.9695, 5.9695],
[5.9695, 5.9695, 5.9695]]])
Running INT8 inference..[OK]
/usr/local/lib/python3.6/dist-packages/torch/quantization/observer.py:121: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.
reduce_range will be deprecated in a future release of PyTorch."