 # Is there an alternative to do batched matrix multiplication on Quantized Tensors?

Hi,

I am trying to do post training static quantization, however, I am running into issues where certain operations are not defined for `QuantizedCPUTensorId`.

Minimal reproducible example:

```python
>>> import torch
>>>
>>> A = torch.Tensor([[2,2], [3,3]]).unsqueeze(0)
>>> B = torch.Tensor([[2,3], [2,3]]).unsqueeze(0)
>>> scale, zero_point, dtype = 1.0, 2, torch.qint8
>>> qA = torch.quantize_per_tensor(A, scale, zero_point, dtype)
>>> qB = torch.quantize_per_tensor(B, scale, zero_point, dtype)
>>> torch.matmul(A,B)
tensor([[[ 8., 12.],
         [12., 18.]]])
>>> torch.matmul(qA,qB)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Could not run 'aten::bmm' with arguments from the 'QuantizedCPUTensorId' backend. 'aten::bmm' is only available for these backends: [CPUTensorId, VariableTensorId].
```

Is there an alternative way to accomplish the same thing?
I know some quantized operations are defined here: https://pytorch.org/docs/stable/quantization.html#floatfunctional, but what would be the optimal approach?

If possible, try using `nn.Linear` instead of `aten::bmm`.

Currently, the only way is to implement the quantized operator for `aten::bmm`.
One easy way could be to implement it with the `quantized::linear` operator, looping over the batch dimension. We will look into implementing this operator in the future.
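Until a native quantized `bmm` kernel exists, a common workaround (a sketch of the idea, not an official PyTorch API; the helper name `quantized_bmm_fallback` and its output scale/zero-point are assumptions) is to dequantize, run the float `torch.bmm`, and requantize the result. This trades some performance for compatibility:

```python
import torch

def quantized_bmm_fallback(qx, qy, scale, zero_point, dtype=torch.qint8):
    """Batched matmul for quantized tensors via a float round-trip.

    This is a workaround sketch, not a native quantized kernel: it
    dequantizes both inputs, runs the ordinary float bmm, and then
    requantizes the result with the given output scale/zero_point.
    """
    out = torch.bmm(qx.dequantize(), qy.dequantize())
    return torch.quantize_per_tensor(out, scale, zero_point, dtype)

# Same tensors as in the example above
A = torch.Tensor([[2, 2], [3, 3]]).unsqueeze(0)
B = torch.Tensor([[2, 3], [2, 3]]).unsqueeze(0)
qA = torch.quantize_per_tensor(A, 1.0, 2, torch.qint8)
qB = torch.quantize_per_tensor(B, 1.0, 2, torch.qint8)

qC = quantized_bmm_fallback(qA, qB, scale=1.0, zero_point=2)
print(qC.dequantize())  # matches torch.matmul(A, B) up to quantization error
```

With `scale=1.0` all the values here are exactly representable, so the result matches the float `torch.matmul(A, B)`; in general the output scale and zero point should be chosen from the observed range of the result.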

Hi @supriyar, thanks for the response.

Yes, I had thought about that, but wouldn't that operation be suboptimal? If there is no alternative, though, I guess it will have to do for now.

Seems like https://pytorch.org/docs/stable/quantization.html#torch.nn.quantized.functional.linear is not a viable option either: it requires the input tensor to be unsigned (`quint8`), whereas the operation here is explicitly between two `qint8` tensors.

```python
>>> torch.nn.quantized.functional.linear(qA[0,], qB[0,])
RuntimeError: expected scalar type QUInt8 but found QInt8
```

Do you need both of the inputs to be `qint8`? If you change `qA` to `quint8`, it should work.
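A minimal sketch of that suggestion, assuming the convention used by quantized linear kernels (activations as `quint8`, weights as `qint8` with zero point 0). Note that `linear` computes `input @ weight.T`, so the weight is quantized from `B.t()` here to reproduce the earlier `matmul(A, B)` result:

```python
import torch

A = torch.Tensor([[2, 2], [3, 3]])
B = torch.Tensor([[2, 3], [2, 3]])

# Activation: unsigned quint8, as the linear op expects
qA = torch.quantize_per_tensor(A, scale=1.0, zero_point=2, dtype=torch.quint8)
# Weight: signed qint8; quantized kernels generally expect zero_point == 0.
# Quantize B.t() because linear computes input @ weight.T.
qW = torch.quantize_per_tensor(B.t().contiguous(), scale=1.0, zero_point=0,
                               dtype=torch.qint8)

out = torch.nn.quantized.functional.linear(qA, qW, scale=1.0, zero_point=0)
print(out.dequantize())  # same values as torch.matmul(A, B)
```

With `scale=1.0` the result is exact; in practice the output scale and zero point come from calibration.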