Is there an alternative way to do batched matrix multiplication on quantized tensors?

Hi,

I am trying to do post-training static quantization, but I am running into issues where certain operations are not defined for QuantizedCPUTensorId.

Minimal reproducible example:

>>> import torch
>>> 
>>> A = torch.Tensor([[2,2], [3,3]]).unsqueeze(0)
>>> B = torch.Tensor([[2,3], [2,3]]).unsqueeze(0)
>>> scale, zero_point, dtype = 1.0, 2, torch.qint8
>>> qA = torch.quantize_per_tensor(A, scale, zero_point, dtype)
>>> qB = torch.quantize_per_tensor(B, scale, zero_point, dtype)
>>> torch.matmul(A,B)
tensor([[[ 8., 12.],
         [12., 18.]]])
>>> torch.matmul(qA,qB)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Could not run 'aten::bmm' with arguments from the 'QuantizedCPUTensorId' backend. 'aten::bmm' is only available for these backends: [CPUTensorId, VariableTensorId].

Are there alternative ways to accomplish the same thing?
I know that certain operations are defined here: https://pytorch.org/docs/stable/quantization.html#floatfunctional, but what would be the optimal way?

If possible, try using nn.Linear instead of aten::bmm.

Currently, the only way is to implement the quantized operator for aten::bmm.
One easy way to do that could be to loop over the batch dimension and call the quantized::linear operator on each slice. We will be looking into implementing this operator in the future.
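
Until a quantized kernel for aten::bmm exists, one stopgap is to fall back to float for just this operation: dequantize both inputs, run the ordinary batched matmul, and requantize the result. The sketch below assumes the extra dequantize/quantize round trip is acceptable and that the caller supplies the output scale and zero point (in a real model these would normally come from calibration). `bmm_via_float` is a hypothetical helper name, not a PyTorch API.

```python
import torch


def bmm_via_float(qa, qb, out_scale, out_zero_point, out_dtype=torch.qint8):
    """Hypothetical helper: batched matmul of quantized tensors via float."""
    # dequantize() returns ordinary float tensors, for which matmul/bmm exist.
    out = torch.matmul(qa.dequantize(), qb.dequantize())
    # Requantize with caller-chosen output parameters (an assumption here).
    return torch.quantize_per_tensor(out, out_scale, out_zero_point, out_dtype)


A = torch.tensor([[2.0, 2.0], [3.0, 3.0]]).unsqueeze(0)
B = torch.tensor([[2.0, 3.0], [2.0, 3.0]]).unsqueeze(0)
qA = torch.quantize_per_tensor(A, 1.0, 2, torch.qint8)
qB = torch.quantize_per_tensor(B, 1.0, 2, torch.qint8)

qC = bmm_via_float(qA, qB, out_scale=1.0, out_zero_point=2)
print(qC.dequantize())
```

This keeps the rest of the model quantized, but the matmul itself runs in float, so it gives none of the speed benefit of a true quantized kernel.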

Hi @supriyar, thanks for the response.

Yes, I had thought about that, but wouldn’t that operation be suboptimal? However, if there is no alternative, I guess it will have to do for now.

It seems that https://pytorch.org/docs/stable/quantization.html#torch.nn.quantized.functional.linear is not a viable option. It requires the input tensor to be unsigned (quint8), whereas this operation is explicitly between two qint8 tensors.

>>> torch.nn.quantized.functional.linear(qA[0,], qB[0,])
RuntimeError: expected scalar type QUInt8 but found QInt8

Do you need both of the inputs to be qint8? If you change qA to be quint8, it would work.
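
Putting the suggestions in this thread together, a loop over the batch dimension might look like the sketch below. It assumes the left operand can be quantized as quint8 and that the right operand can act as the qint8 weight. Because linear computes input @ weight.T, each right-hand matrix is transposed before it is quantized. `batched_qmatmul` is a hypothetical helper, the import path of the quantized functional module differs between PyTorch releases, a quantized CPU engine (such as fbgemm or qnnpack) is assumed to be available, and the scale and zero point are illustrative values.

```python
import torch

try:
    # Newer releases keep the quantized functional ops under torch.ao.
    from torch.ao.nn.quantized import functional as qF
except ImportError:
    # Older releases, such as the one used in this thread.
    from torch.nn.quantized import functional as qF


def batched_qmatmul(a, b, scale=1.0, zero_point=2):
    """Hypothetical helper: a[i] @ b[i] for every i, via quantized linear."""
    outs = []
    for a_i, b_i in zip(a, b):
        # The input (activation) must be unsigned, so quantize it as quint8.
        qa_i = torch.quantize_per_tensor(a_i, scale, zero_point, torch.quint8)
        # linear computes input @ weight.T, so the weight is b_i transposed.
        qw_i = torch.quantize_per_tensor(
            b_i.t().contiguous(), scale, zero_point, torch.qint8
        )
        outs.append(qF.linear(qa_i, qw_i, scale=scale, zero_point=zero_point))
    # Dequantize before stacking so no quantized stack kernel is required.
    return torch.stack([o.dequantize() for o in outs])


A = torch.tensor([[2.0, 2.0], [3.0, 3.0]]).unsqueeze(0)
B = torch.tensor([[2.0, 3.0], [2.0, 3.0]]).unsqueeze(0)
C = batched_qmatmul(A, B)
print(C)
```

The weight is prepacked on every call here, which is part of why this loop is likely to be slower than a dedicated quantized bmm kernel would be.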