Is there an alternative way to do batched matrix multiplication on quantized tensors?

Hi,

I am trying to do post-training static quantization; however, I am running into issues where certain operations are not defined for QuantizedCPUTensorId.

Minimal reproducible example:

>>> import torch
>>> 
>>> A = torch.Tensor([[2,2], [3,3]]).unsqueeze(0)
>>> B = torch.Tensor([[2,3], [2,3]]).unsqueeze(0)
>>> scale, zero_point, dtype = 1.0, 2, torch.qint8
>>> qA = torch.quantize_per_tensor(A, scale, zero_point, dtype)
>>> qB = torch.quantize_per_tensor(B, scale, zero_point, dtype)
>>> torch.matmul(A,B)
tensor([[[ 8., 12.],
         [12., 18.]]])
>>> torch.matmul(qA,qB)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Could not run 'aten::bmm' with arguments from the 'QuantizedCPUTensorId' backend. 'aten::bmm' is only available for these backends: [CPUTensorId, VariableTensorId].

Are there alternative ways to accomplish the same thing?
I know some operations are defined here: https://pytorch.org/docs/stable/quantization.html#floatfunctional, but what would be the optimal way?

If possible, try using nn.Linear instead of aten::bmm.
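The reason nn.Linear can stand in for bmm is that `F.linear(x, W)` computes `x @ W.T`. A minimal float sketch (no quantization yet), reusing the tensors from the question, that loops over the batch dimension and passes the transposed right-hand operand as the weight:

```python
import torch
import torch.nn.functional as F

A = torch.tensor([[2., 2.], [3., 3.]]).unsqueeze(0)  # shape (1, 2, 2)
B = torch.tensor([[2., 3.], [2., 3.]]).unsqueeze(0)  # shape (1, 2, 2)

# F.linear(x, W) computes x @ W.T, so W = B[i].T yields A[i] @ B[i]
out = torch.stack([F.linear(A[i], B[i].t()) for i in range(A.shape[0])])
print(out)  # same values as torch.matmul(A, B)
```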

Currently the only way is to implement the quantized operator for aten::bmm.
One easy way would be to implement it with the quantized::linear operator, looping over the batch dimension. We will be looking into implementing this operator in the future.
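A minimal sketch of that loop, assuming the default x86 quantized engine (fbgemm), quint8 for the left operand and qint8 for the right one, and a scale of 1.0 with zero point 0 so the small integer example is exact. `quantized_bmm` is a hypothetical helper, not a PyTorch API; `torch.nn.quantized.functional` may live under `torch.ao.nn.quantized.functional` in newer releases. Note that the functional `linear` packs the weight on every call, so this is likely slower than a native kernel would be.

```python
import torch
import torch.nn.quantized.functional as qF

def quantized_bmm(qa, qb, out_scale, out_zero_point):
    """Hypothetical helper: qa @ qb for 3-D quantized tensors, one batch slice at a time.

    qa: quint8 activations, shape (batch, n, k)
    qb: qint8 operand used as the weight, shape (batch, k, m)
    """
    outs = []
    for i in range(qa.shape[0]):
        # quantized linear computes x @ W.T, so pass qb[i].T (shape (m, k)) as the weight
        w = qb[i].t().contiguous()
        outs.append(qF.linear(qa[i], w, scale=out_scale, zero_point=out_zero_point))
    # every slice shares out_scale / out_zero_point, so the results can be stacked
    return torch.stack(outs)

A = torch.tensor([[2., 2.], [3., 3.]]).unsqueeze(0)
B = torch.tensor([[2., 3.], [2., 3.]]).unsqueeze(0)
qA = torch.quantize_per_tensor(A, 1.0, 0, torch.quint8)
qB = torch.quantize_per_tensor(B, 1.0, 0, torch.qint8)

out = quantized_bmm(qA, qB, out_scale=1.0, out_zero_point=0)
print(out.dequantize())
```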

Hi @supriyar, thanks for the response.

Yes, I had thought about that, but wouldn't that be suboptimal? Still, if there is no alternative, I guess it will have to do for now.

It seems that https://pytorch.org/docs/stable/quantization.html#torch.nn.quantized.functional.linear is not a viable option: it requires the input tensor to be unsigned (quint8), whereas my operation is explicitly between two qint8 tensors.

>>> torch.nn.quantized.functional.linear(qA[0,], qB[0,])
RuntimeError: expected scalar type QUInt8 but found QInt8

Do you need both inputs to be qint8? If you change qA to quint8, it should work.
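A minimal sketch of that suggestion on a single slice, assuming zero point 0 and scale 1.0 for both operands. It also shows an easy-to-miss detail: like its float counterpart, quantized `linear` computes `x @ W.T`, so passing the second operand unchanged (as in the failing call) gives `A @ B.T`; transposing it first gives `A @ B`.

```python
import torch
import torch.nn.quantized.functional as qF

A = torch.tensor([[2., 2.], [3., 3.]])
B = torch.tensor([[2., 3.], [2., 3.]])

# the input (activation) must be quint8; the weight stays qint8
qA = torch.quantize_per_tensor(A, 1.0, 0, torch.quint8)
qB = torch.quantize_per_tensor(B, 1.0, 0, torch.qint8)
qBt = torch.quantize_per_tensor(B.t().contiguous(), 1.0, 0, torch.qint8)

as_is = qF.linear(qA, qB, scale=1.0, zero_point=0)        # computes A @ B.T
transposed = qF.linear(qA, qBt, scale=1.0, zero_point=0)  # computes A @ B
print(as_is.dequantize())
print(transposed.dequantize())
```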

Any update on this? Are you planning to support it in the near future?

I've come up with the following code. As suggested, I perform the matrix multiplication with nn.Linear.

What do you think?

import torch
import torch.nn as nn

class BatchedMatMul(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        # the 3x3 weight is a placeholder; it is overwritten with input1[b] at every step
        self.linear = nn.Linear(3, 3, bias=False)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, input1, input2):
        y = []
        for b in range(input1.shape[0]):
            print(f"Linear's type: {type(self.linear)}")
            print(f"Linear's weight type: {type(self.linear.weight)}")
            # nn.Linear computes x @ W.T, so each step yields input2[b] @ input1[b].T
            if isinstance(self.linear.weight, nn.Parameter):
                # float (and calibration) path: load input1[b] as the weight
                self.linear.weight.requires_grad = False
                self.linear.weight.copy_(self.quant(input1[b]))
                y.append(self.linear(self.quant(input2[b])))
            else:
                # quantized path: weight() returns the packed qint8 weight; reuse its qparams
                w_old = self.linear.weight()
                scale = w_old.q_per_channel_scales()
                zero_point = w_old.q_per_channel_zero_points()
                axis = w_old.q_per_channel_axis()  # output-channel axis (0 for Linear)
                w = torch.quantize_per_channel(input1[b], scale, zero_point, axis, torch.qint8)
                self.linear.set_weight_bias(w, None)
                y.append(self.linear(self.quant(input2[b])))

        return self.dequant(torch.stack(y))

print("Construct model...")
matmul = BatchedMatMul()
print("Construct model... [OK]")

matmul.eval()
print("Running FP32 inference...")
inp = torch.ones(3, 3).repeat(2, 1, 1)  # batch of two 3x3 all-ones matrices
y = matmul(2 * inp, inp)
print("FP32 output...")
print(y)
print("Running FP32 inference... [OK]")

print("Quantizing...")
# note: this is the QAT qconfig; for post-training static quantization,
# torch.quantization.get_default_qconfig('fbgemm') is the usual choice
matmul.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
matmul_prepared = torch.quantization.prepare(matmul)
matmul_prepared(2 * inp, inp)  # calibration pass
model_int8 = torch.quantization.convert(matmul_prepared)
print("Quantizing... [OK]")
print("Running INT8 inference...")
y = model_int8(2 * inp, inp)
print("Int8 Output")
print(y)
print("Running INT8 inference... [OK]")

Output:

Construct model...
Construct model... [OK]
Running FP32 inference...
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weight type: <class 'torch.nn.parameter.Parameter'>
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weight type: <class 'torch.nn.parameter.Parameter'>
FP32 output...
tensor([[[6., 6., 6.],
         [6., 6., 6.],
         [6., 6., 6.]],

        [[6., 6., 6.],
         [6., 6., 6.],
         [6., 6., 6.]]])
Running FP32 inference... [OK]
Quantizing...
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weight type: <class 'torch.nn.parameter.Parameter'>
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weight type: <class 'torch.nn.parameter.Parameter'>
Quantizing... [OK]
Running INT8 inference...
Linear's type: <class 'torch.nn.quantized.modules.linear.Linear'>
Linear's weight type: <class 'method'>
Linear's type: <class 'torch.nn.quantized.modules.linear.Linear'>
Linear's weight type: <class 'method'>
Int8 Output
tensor([[[5.9695, 5.9695, 5.9695],
         [5.9695, 5.9695, 5.9695],
         [5.9695, 5.9695, 5.9695]],

        [[5.9695, 5.9695, 5.9695],
         [5.9695, 5.9695, 5.9695],
         [5.9695, 5.9695, 5.9695]]])
Running INT8 inference... [OK]
/usr/local/lib/python3.6/dist-packages/torch/quantization/observer.py:121: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  reduce_range will be deprecated in a future release of PyTorch."
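One caveat with the module above, as far as I can tell: in the quantized branch, every new `input1[b]` is quantized with the per-channel scales observed for the weight at `convert` time, so weight values outside that calibration range would be clipped. A possible alternative, sketched under the assumption that recomputing quantization parameters on every call is acceptable, is to derive fresh symmetric qint8 parameters for each weight with a `MinMaxObserver`. `quantize_weight_per_call` is a hypothetical helper; its result could be passed to `set_weight_bias` in place of `w`.

```python
import torch
from torch.quantization import MinMaxObserver

def quantize_weight_per_call(w):
    """Hypothetical helper: quantize a float weight to qint8 with its own qparams."""
    # symmetric per-tensor qparams keep the weight zero point at 0
    obs = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
    obs(w)  # record the min/max of this particular weight
    scale, zero_point = obs.calculate_qparams()
    return torch.quantize_per_tensor(w, float(scale), int(zero_point), torch.qint8)

# a weight whose range is far wider than the 2 * ones seen during calibration
w = torch.tensor([[10.0, -4.0, 1.0], [0.5, 2.0, -8.0], [3.0, 0.0, 6.0]])
qw = quantize_weight_per_call(w)
print(qw.q_scale(), qw.q_zero_point())
print(qw.dequantize())
```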

Currently, I think we only support quint8 for activations and qint8 for weights.

Currently we have no plans to support bmm. One workaround is to put a DeQuantStub before and a QuantStub after the bmm op, so that bmm runs in floating point and is not quantized.
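A minimal eager-mode sketch of that workaround. The module name `BmmInFloat` is hypothetical, the two input QuantStubs stand in for whatever quantized layers would precede the bmm, and `default_qconfig` (MinMax observers) is assumed for calibration. Each operand is dequantized right before `torch.bmm`, and the result is re-quantized afterwards, at the cost of one extra dequantize/quantize round trip.

```python
import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

class BmmInFloat(nn.Module):
    """Hypothetical block: quantized tensors around a bmm that runs in float."""
    def __init__(self):
        super().__init__()
        self.quant_a = QuantStub()      # stand-ins for upstream quantized layers
        self.quant_b = QuantStub()
        self.dequant_a = DeQuantStub()  # back to float right before bmm
        self.dequant_b = DeQuantStub()
        self.quant_out = QuantStub()    # re-quantize for downstream quantized layers
        self.dequant_out = DeQuantStub()

    def forward(self, a, b):
        qa, qb = self.quant_a(a), self.quant_b(b)
        # ... quantized ops on qa / qb would go here ...
        out = torch.bmm(self.dequant_a(qa), self.dequant_b(qb))  # float bmm
        qout = self.quant_out(out)
        # ... quantized ops on qout would go here ...
        return self.dequant_out(qout)

torch.manual_seed(0)
a, b = torch.rand(2, 3, 4), torch.rand(2, 4, 5)

model = BmmInFloat().eval()
model.qconfig = torch.quantization.default_qconfig
prepared = torch.quantization.prepare(model)
prepared(a, b)  # calibration pass
int8_model = torch.quantization.convert(prepared)
out = int8_model(a, b)
print(out.shape)
```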