# Is there an alternative to do batched matrix multiplication on Quantized Tensors?

Hi,

I am trying to do post-training static quantization, but I am running into issues where certain operations are not defined for the `QuantizedCPUTensorId` backend.

Minimal reproducible example:

``````
>>> import torch
>>>
>>> A = torch.Tensor([[2,2], [3,3]]).unsqueeze(0)
>>> B = torch.Tensor([[2,3], [2,3]]).unsqueeze(0)
>>> scale, zero_point, dtype = 1.0, 2, torch.qint8
>>> qA = torch.quantize_per_tensor(A, scale, zero_point, dtype)
>>> qB = torch.quantize_per_tensor(B, scale, zero_point, dtype)
>>> torch.matmul(A,B)
tensor([[[ 8., 12.],
         [12., 18.]]])
>>> torch.matmul(qA,qB)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Could not run 'aten::bmm' with arguments from the 'QuantizedCPUTensorId' backend. 'aten::bmm' is only available for these backends: [CPUTensorId, VariableTensorId].
``````

Are there alternatives for accomplishing the same thing?
I know some operations are supported through `FloatFunctional` (https://pytorch.org/docs/stable/quantization.html#floatfunctional), but what would be the optimal way to do this?

If possible, try using `nn.Linear` instead of `aten::bmm`.
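This works when one operand of the product is fixed (a learned or constant matrix): `x @ W` is exactly what `nn.Linear` computes when its weight is set to `W.t()`, and `nn.Linear` is then handled by the standard `prepare`/`convert` flow. A minimal float sketch, with arbitrary shapes chosen for illustration; it does not cover the question's case, where both operands change with each batch entry.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A constant right-hand operand shared by every batch entry (illustrative shape).
W = torch.randn(4, 5)

# nn.Linear computes x @ weight.T, so store W transposed as the weight.
lin = nn.Linear(4, 5, bias=False)
with torch.no_grad():
    lin.weight.copy_(W.t())

x = torch.randn(2, 3, 4)          # (batch, rows, features)
out_linear = lin(x)               # Linear applies to the last dim, batching over the rest
out_matmul = torch.matmul(x, W)   # the original batched product

print(torch.allclose(out_linear, out_matmul, atol=1e-5))
```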

Currently, the only way is to implement a quantized version of `aten::bmm`.
One easy way would be to build it on the `quantized::linear` operator by looping over the batch dimension. We will look into implementing this operator in the future.
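A float sketch of the looping idea, checked against `torch.bmm`. The helper name `bmm_via_linear` is made up for this example; in a quantized model the per-slice call would go to a quantized linear (for example `torch.nn.quantized.functional.linear`, with a `quint8` input and a `qint8` weight) instead of `F.linear`.

```python
import torch
import torch.nn.functional as F


def bmm_via_linear(a, b):
    """Hypothetical helper: batched a @ b, one batch entry at a time.

    a: (B, n, k), b: (B, k, m). F.linear(x, w) computes x @ w.T,
    so each slice passes w = b[i].T, of shape (m, k).
    """
    return torch.stack([F.linear(a[i], b[i].t()) for i in range(a.shape[0])])


torch.manual_seed(0)
a = torch.randn(2, 3, 4)
b = torch.randn(2, 4, 5)
out = bmm_via_linear(a, b)
print(torch.allclose(out, torch.bmm(a, b), atol=1e-5))
```

Each batch entry is a separate kernel call, so the loop adds per-call overhead compared with a single batched kernel.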

Hi @supriyar, thanks for the response.

Yes, I had thought about that, but wouldn’t that be suboptimal? If there is no alternative, though, I guess it will have to do for now.

It seems that https://pytorch.org/docs/stable/quantization.html#torch.nn.quantized.functional.linear is not a viable option: it requires the input tensor to be unsigned (`quint8`), whereas my operation is explicitly between two `qint8` tensors.

``````
>>> torch.nn.quantized.functional.linear(qA[0,], qB[0,])
RuntimeError: expected scalar type QUInt8 but found QInt8
``````

Do you need both inputs to be `qint8`? If you change `qA` to `quint8`, it would work.
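A minimal sketch of that suggestion with the matrices from the question (batch dimension dropped). Two assumptions: the weight is quantized with `zero_point=0` (the usual symmetric setting for weights), and an explicit output `scale=2.0, zero_point=0` is passed so the requantization factor stays below 1, which some backends require. Note that `linear(x, w)` computes `x @ w.T`, so the question's call `linear(qA[0,], qB[0,])` would compute `A @ B.T`; passing `B` transposed gives `A @ B`.

```python
import torch

try:  # newer releases keep the quantized functional ops under torch.ao
    from torch.ao.nn.quantized import functional as qF
except ImportError:  # older releases, as used in this thread
    from torch.nn.quantized import functional as qF

A = torch.Tensor([[2, 2], [3, 3]])
B = torch.Tensor([[2, 3], [2, 3]])

# Activation (input) as quint8, which is what the kernel expects.
qA = torch.quantize_per_tensor(A, 1.0, 2, torch.quint8)
# Weight stays qint8; transposed because linear multiplies by weight.T.
qW = torch.quantize_per_tensor(B.t().contiguous(), 1.0, 0, torch.qint8)

# Output quantization parameters chosen by hand for this example.
q_out = qF.linear(qA, qW, scale=2.0, zero_point=0)
print(q_out.dequantize())
print(torch.matmul(A, B))
```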

I’ve come up with the following code. As suggested, I perform the matrix multiplication with `nn.Linear`.

What do you think?

``````
import torch
import torch.nn as nn


class BatchedMatMul(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear = nn.Linear(3, 3, bias=False)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, input1, input2):
        # For each batch entry b this computes input2[b] @ input1[b].T,
        # because nn.Linear multiplies its input by the transposed weight.
        y = []
        for b in range(input1.shape[0]):
            print(f"Linear's type: {type(self.linear)}")
            print(f"Linear's weigth type: {type(self.linear.weight)}")
            if isinstance(self.linear.weight, nn.Parameter):
                # Float / observed model: overwrite the weight with input1[b].
                # no_grad() is required: an in-place write to a leaf
                # Parameter that requires grad raises a RuntimeError.
                with torch.no_grad():
                    self.linear.weight.copy_(self.quant(input1[b]))
                y.append(self.linear(self.quant(input2[b])))
            else:
                # Converted model: `weight` is now a method returning the
                # qint8 weight. Its calibrated per-channel qparams are reused,
                # so values outside the calibrated range get clamped.
                scale = self.linear.weight().q_per_channel_scales()
                zero_point = self.linear.weight().q_per_channel_zero_points()
                # Axis 0 is the output-channel axis of a Linear weight.
                w = torch.quantize_per_channel(input1[b], scale, zero_point, 0, torch.qint8)
                self.linear.set_weight_bias(w, b=None)
                y.append(self.linear(self.quant(input2[b])))

        return self.dequant(torch.stack(y))


print("Cronstruct model...")
matmul = BatchedMatMul()
print("Cronstruct model... [OK]")

matmul.eval()
print("Running FP32 inference...")
inp = torch.ones(3, 3).repeat(2, 1, 1)  # shape (2, 3, 3)
y = matmul(2 * inp, inp)
print("FP32 output...")
print(y)
print("Running FP32 inference... [OK]")

print("Quantizing...")
# QAT qconfig as in the original run; for post-training static quantization,
# torch.quantization.get_default_qconfig('fbgemm') is the usual choice.
matmul.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
matmul_prepared = torch.quantization.prepare(matmul)
matmul_prepared(2 * inp, inp)  # calibration pass
model_int8 = torch.quantization.convert(matmul_prepared)
print("Quantizing... [OK]")
print("Running INT8 inference...")
y = model_int8(2 * inp, inp)
print("Int8 Output")
print(y)
print("Running INT8 inference..[OK]")
``````

Output:

``````
Cronstruct model...
Cronstruct model... [OK]
Running FP32 inference...
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
FP32 output...
tensor([[[6., 6., 6.],
         [6., 6., 6.],
         [6., 6., 6.]],

        [[6., 6., 6.],
         [6., 6., 6.],
         [6., 6., 6.]]])
Running FP32 inference... [OK]
Quantizing...
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
Quantizing... [OK]
Running INT8 inference...
Linear's type: <class 'torch.nn.quantized.modules.linear.Linear'>
Linear's weigth type: <class 'method'>
Linear's type: <class 'torch.nn.quantized.modules.linear.Linear'>
Linear's weigth type: <class 'method'>
Int8 Output
tensor([[[5.9695, 5.9695, 5.9695],
         [5.9695, 5.9695, 5.9695],
         [5.9695, 5.9695, 5.9695]],

        [[5.9695, 5.9695, 5.9695],
         [5.9695, 5.9695, 5.9695],
         [5.9695, 5.9695, 5.9695]]])
Running INT8 inference..[OK]
/usr/local/lib/python3.6/dist-packages/torch/quantization/observer.py:121: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
reduce_range will be deprecated in a future release of PyTorch."
``````

Currently, I think we only support `quint8` for activations and `qint8` for weights.
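Given that split, an activation already quantized to `qint8` can be moved to `quint8` without changing any represented value: adding 128 to both the stored integers and the zero point maps [-128, 127] onto [0, 255], while the dequantized value `scale * (q - zero_point)` stays the same. A small sketch using the tensor from the question:

```python
import torch

A = torch.Tensor([[2, 2], [3, 3]]).unsqueeze(0)
qA = torch.quantize_per_tensor(A, 1.0, 2, torch.qint8)

# Same scale, zero point shifted by +128: every qint8 value q maps to q + 128.
qA_u8 = torch.quantize_per_tensor(
    qA.dequantize(), qA.q_scale(), qA.q_zero_point() + 128, torch.quint8
)

print(qA.int_repr())      # int8 storage
print(qA_u8.int_repr())   # uint8 storage, each entry shifted by 128
print(torch.equal(qA.dequantize(), qA_u8.dequantize()))
```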

Currently we do not have plans to support `bmm`. One workaround is to put a `DeQuantStub` and a `QuantStub` around the `bmm` op so that it is not quantized and runs in floating point.
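A minimal sketch of that workaround, assuming a hypothetical model in which a quantized `nn.Linear` feeds a `bmm` whose result goes through another `nn.Linear`. After `convert`, the `DeQuantStub`/`QuantStub` pair becomes dequantize/quantize ops, so `torch.bmm` only sees float tensors. The qconfig is picked from the active `torch.backends.quantized.engine` rather than hard-coding `'fbgemm'`.

```python
import torch
import torch.nn as nn


class QuantizedLinearFloatBmm(nn.Module):
    """Hypothetical model: quantized Linear layers with a float bmm between them."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(4, 4)
        self.dequant_before_bmm = torch.quantization.DeQuantStub()
        self.quant_after_bmm = torch.quantization.QuantStub()
        self.linear2 = nn.Linear(5, 5)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x, other):
        x = self.linear1(self.quant(x))   # quantized after convert()
        x = self.dequant_before_bmm(x)    # back to float ...
        x = torch.bmm(x, other)           # ... so bmm runs on a regular CPU tensor
        x = self.quant_after_bmm(x)       # re-quantize for the next quantized layer
        x = self.linear2(x)
        return self.dequant(x)


torch.manual_seed(0)
x = torch.randn(2, 3, 4)
other = torch.randn(2, 4, 5)

model = QuantizedLinearFloatBmm().eval()
model.qconfig = torch.quantization.get_default_qconfig(torch.backends.quantized.engine)
prepared = torch.quantization.prepare(model)
prepared(x, other)                        # calibration pass
model_int8 = torch.quantization.convert(prepared)

out = model_int8(x, other)
print(type(model_int8.linear1))
print(out.shape)
```

The cost is one dequantize/quantize round trip and a float `bmm` per call; the layers on either side stay quantized.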