Change Linear layer weights in a quantized model

Since matrix multiplication is not supported by model quantization, I'm performing it with an nn.Linear layer whose weights I change in every forward pass.
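A side note on what that loop actually computes: nn.Linear applies y = x @ W^T, so with weight = input1[b] and input = input2[b], each step yields input2[b] @ input1[b]^T rather than input1[b] @ input2[b]. With the all-ones inputs used in the full code, the transpose is invisible. Here is a minimal pure-Python sketch (no PyTorch; the helper names are made up for illustration) that checks this on a non-symmetric example:

```python
# Pure-Python sketch (no PyTorch): nn.Linear(bias=False) computes x @ W^T,
# so "weight := input1[b], input := input2[b]" gives input2[b] @ input1[b]^T.

def matmul(a, b):
    """Plain matrix product of two lists-of-lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(row) for row in zip(*m)]

def linear(x, weight):
    """What nn.Linear(bias=False) computes: x @ weight^T."""
    return matmul(x, transpose(weight))

def batched_via_linear(input1, input2):
    """Mimics the loop in BatchedMatMul.forward (weight := input1[b])."""
    return [linear(input2[b], input1[b]) for b in range(len(input1))]

def bmm(a, b):
    """Reference batched matmul, like torch.bmm."""
    return [matmul(a[i], b[i]) for i in range(len(a))]

input1 = [[[1, 2], [3, 4]]]
input2 = [[[5, 6], [7, 8]]]

out = batched_via_linear(input1, input2)
# Matches bmm(input2, input1^T), not bmm(input1, input2):
assert out == bmm(input2, [transpose(m) for m in input1])
print(out)  # [[[17, 39], [23, 53]]]
```

So if the goal is torch.bmm(input1, input2), the weight to load would be input2[b] transposed, and the layer input would be input1[b].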

This approach works well for the FP32 model, but it crashes when the model is quantized. The issue is that, once the model is converted to int8, the following lines of code are no longer valid:

self.linear.weight.requires_grad = False
self.linear.weight.copy_ (input1[b])

because in the converted model self.linear.weight is not a torch.nn.Parameter but a method that returns a Tensor.

Any workaround for this?

FULL CODE

import torch
import torch.nn as nn

class BatchedMatMul(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear = nn.Linear(3,3, bias=False)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, input1, input2):
        y = []
        for b in range(input1.shape[0]):
          print(f"Linear's type: {type(self.linear)}")
          print(f"Linear's weigth type: {type(self.linear.weight)}")
          self.linear.weight.requires_grad = False
          self.linear.weight.copy_ (self.quant(input1[b]))
          y.append(self.linear(self.quant(input2[b])))
        return self.dequant(torch.stack(y))

print("Cronstruct model...")
matmul = BatchedMatMul()
print("Cronstruct model... [OK]")

matmul.eval()
print("Running FP32 inference...")
inp = torch.ones(3, 3).repeat(2,1,1)
y = matmul(inp, inp)
print(y)
print("Running FP32 inference... [OK]")

print("Quantizing...")
matmul.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
matmul_prepared = torch.quantization.prepare(matmul)
matmul_prepared(inp, inp)
model_int8 = torch.quantization.convert(matmul_prepared)
print("Quantizing... [OK]")
print("Running INT8 inference...")
y = model_int8(inp, inp)
print(y)
print("Running INT8 inference..[OK]")

OUTPUT

Cronstruct model...
Cronstruct model... [OK]
Running FP32 inference...
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
tensor([[[3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.]],

        [[3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.]]])
Running FP32 inference... [OK]
Quantizing...
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
Quantizing... [OK]
Running INT8 inference...
Linear's weigth type: <class 'method'>
/usr/local/lib/python3.6/dist-packages/torch/quantization/observer.py:121: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  reduce_range will be deprecated in a future release of PyTorch."
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-81-024fd82f94de> in <module>()
     34 print("Quantizing... [OK]")
     35 print("Running INT8 inference...")
---> 36 y = model_int8(inp, inp)
     37 print(y)
     38 print("Running INT8 inference..[OK]")

1 frames
<ipython-input-81-024fd82f94de> in forward(self, input1, input2)
     10         for b in range(input1.shape[0]):
     11           print(f"Linear's weigth type: {type(self.linear.weight)}")
---> 12           self.linear.weight.requires_grad = False
     13           self.linear.weight.copy_ (input1[b])
     14           y.append(self.linear(input2[b]))

AttributeError: 'method' object has no attribute 'requires_grad'

I came up with the following:

class BatchedMatMul(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear = nn.Linear(3,3, bias=False)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, input1, input2):
        y = []
        for b in range(input1.shape[0]):
          print(f"Linear's type: {type(self.linear)}")
          print(f"Linear's weigth type: {type(self.linear.weight)}")
          if isinstance(self.linear.weight, nn.Parameter):
            self.linear.weight.requires_grad = False
            self.linear.weight.copy_ (self.quant(input1[b]))
            y.append(self.linear(self.quant(input2[b])))
          else:
            self.linear.set_weight_bias(self.quant(input1[b]), b=None)
            y.append(self.linear(self.quant(input2[b])))
          
        return self.dequant(torch.stack(y))

self.linear has changed from torch.nn.modules.linear.Linear to torch.nn.quantized.modules.linear.Linear, so its methods and attributes are different.

Nevertheless, this approach still throws an error, because the quantized linear layer expects a signed integer (QInt8) weight but is given an unsigned one (QUInt8):

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-112-629b4f9ebdc0> in <module>()
     42 print("Quantizing... [OK]")
     43 print("Running INT8 inference...")
---> 44 y = model_int8.forward(inp, inp)
     45 print(y)
     46 print("Running INT8 inference..[OK]")

2 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/quantized/modules/linear.py in set_weight_bias(self, weight, bias)
     21     def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:
     22         if self.dtype == torch.qint8:
---> 23             self._packed_params = torch.ops.quantized.linear_prepack(weight, bias)
     24         elif self.dtype == torch.float16:
     25             self._packed_params = torch.ops.quantized.linear_prepack_fp16(weight, bias)

RuntimeError: expected scalar type QInt8 but found QUInt8
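For context: with the default fbgemm qconfig, activations (including the output of QuantStub) are quantized as quint8 (unsigned, 0 to 255), while the quantized Linear prepacks its weight as qint8 (signed, -128 to 127). That is why passing self.quant(input1[b]) as the weight fails. The pure-Python sketch below illustrates the affine mapping q = clamp(round(x / scale) + zero_point, qmin, qmax) for both ranges; the scale and zero-point values are made up for illustration.

```python
# Pure-Python sketch of per-tensor affine quantization (no PyTorch needed):
#   q = clamp(round(x / scale) + zero_point, qmin, qmax)
QRANGES = {"quint8": (0, 255), "qint8": (-128, 127)}

def quantize(values, scale, zero_point, dtype):
    qmin, qmax = QRANGES[dtype]
    return [min(qmax, max(qmin, round(v / scale) + zero_point)) for v in values]

def dequantize(qvalues, scale, zero_point):
    return [(q - zero_point) * scale for q in qvalues]

SCALE = 1 / 128  # hypothetical scale, exact in binary floating point
weights = [-1.0, -0.25, 0.0, 0.25, 1.0]

# Signed and symmetric (zero_point = 0), the usual layout for weights.
q_signed = quantize(weights, SCALE, 0, "qint8")
# Unsigned needs zero_point = 128 to represent negative values.
q_unsigned = quantize(weights, SCALE, 128, "quint8")

print(q_signed)    # [-128, -32, 0, 32, 127]  (1.0 is clipped to 127)
print(q_unsigned)  # [0, 96, 128, 160, 255]   (1.0 is clipped to 255)
```

Both encodings represent the same real values here, but the integer codes and the storage type differ, so the prepack op rejects a quint8 tensor and the float weight has to be re-quantized as qint8 instead.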

I think I’m close… what do you think?

import torch
import torch.nn as nn

class BatchedMatMul(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear = nn.Linear(3,3, bias=False)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, input1, input2):
        y = []
        for b in range(input1.shape[0]):
          print(f"Linear's type: {type(self.linear)}")
          print(f"Linear's weigth type: {type(self.linear.weight)}")
          if isinstance(self.linear.weight, nn.Parameter):
            self.linear.weight.requires_grad = False
            self.linear.weight.copy_ (self.quant(input1[b]))
            y.append(self.linear(self.quant(input2[b])))
          else:
            # Reuse the calibrated per-channel qparams (values outside that range clip)
            w_old = self.linear.weight()
            w = torch.quantize_per_channel(input1[b], w_old.q_per_channel_scales(),
                                           w_old.q_per_channel_zero_points(),
                                           w_old.q_per_channel_axis(), torch.qint8)
            self.linear.set_weight_bias(w, b=None)
            y.append(self.linear(self.quant(input2[b])))
          
        return self.dequant(torch.stack(y))

print("Cronstruct model...")
matmul = BatchedMatMul()
print("Cronstruct model... [OK]")

matmul.eval()
print("Running FP32 inference...")
inp = torch.ones(3, 3).repeat(2,1,1)
y = matmul(2*inp, inp)
print("FP32 output...")
print(y)
print("Running FP32 inference... [OK]")

print("Quantizing...")
matmul.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
matmul_prepared = torch.quantization.prepare(matmul)
matmul_prepared(2*inp, inp)
model_int8 = torch.quantization.convert(matmul_prepared)
print("Quantizing... [OK]")
print("Running INT8 inference...")
y = model_int8.forward(2*inp, inp)
print("Int8 Output")
print(y)
print("Running INT8 inference..[OK]")

OUTPUT

Cronstruct model...
Cronstruct model... [OK]
Running FP32 inference...
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
FP32 output...
tensor([[[6., 6., 6.],
         [6., 6., 6.],
         [6., 6., 6.]],

        [[6., 6., 6.],
         [6., 6., 6.],
         [6., 6., 6.]]])
Running FP32 inference... [OK]
Quantizing...
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
Linear's type: <class 'torch.nn.modules.linear.Linear'>
Linear's weigth type: <class 'torch.nn.parameter.Parameter'>
Quantizing... [OK]
Running INT8 inference...
Linear's type: <class 'torch.nn.quantized.modules.linear.Linear'>
Linear's weigth type: <class 'method'>
Linear's type: <class 'torch.nn.quantized.modules.linear.Linear'>
Linear's weigth type: <class 'method'>
Int8 Output
tensor([[[5.9695, 5.9695, 5.9695],
         [5.9695, 5.9695, 5.9695],
         [5.9695, 5.9695, 5.9695]],

        [[5.9695, 5.9695, 5.9695],
         [5.9695, 5.9695, 5.9695],
         [5.9695, 5.9695, 5.9695]]])
Running INT8 inference..[OK]
/usr/local/lib/python3.6/dist-packages/torch/quantization/observer.py:121: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  reduce_range will be deprecated in a future release of PyTorch."
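One caveat of this version: it reuses the per-channel scales recorded during calibration (q_per_channel_scales()), and quantize_per_channel clamps to the qint8 range, so any entry of input1[b] larger in magnitude than what the observer saw at calibration gets clipped. With the constant test inputs used here, this does not show up, and the small gap between 6.0 and 5.9695 is most likely ordinary rounding from quantizing activations and output. A pure-Python sketch of per-channel symmetric qint8 quantization along axis 0 (one scale per output row, zero_point 0); the scale values are made up to stand in for calibrated ones:

```python
# Pure-Python sketch of per-channel symmetric qint8 quantization along axis 0
# (one scale per row, zero_point = 0). The scales are hypothetical, standing in
# for qparams calibrated on weights whose largest magnitude was about 2.0.
QMIN, QMAX = -128, 127

def quantize_per_channel(rows, scales):
    return [[min(QMAX, max(QMIN, round(v / s))) for v in row]
            for row, s in zip(rows, scales)]

def dequantize_per_channel(qrows, scales):
    return [[q * s for q in row] for row, s in zip(qrows, scales)]

scales = [1 / 64] * 3  # 0.015625 per row, exact in binary floating point

new_weight = [[1.5, 1.5, 1.5],    # inside the calibrated range
              [1.0, -1.0, 0.5],   # inside the calibrated range
              [4.0, -4.0, 3.0]]   # outside the calibrated range: gets clipped

q = quantize_per_channel(new_weight, scales)
deq = dequantize_per_channel(q, scales)
print(deq)  # [[1.5, 1.5, 1.5], [1.0, -1.0, 0.5], [1.984375, -2.0, 1.984375]]
```

The third row comes back as roughly [2, -2, 2] instead of [4, -4, 3], which is the kind of silent error to expect if the runtime weights drift outside the calibration range.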

This feels a bit hacky, and I'm not sure whether it would work. bmm is not supported right now, so why not put a DeQuantStub and a QuantStub around the bmm op to avoid quantizing it?
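A minimal pure-Python sketch of what that suggestion amounts to: the inputs arrive as quantized integers, are dequantized back to float, the unsupported op runs in floating point, and the result is quantized again for the next quantized layer. The scales and zero points are hypothetical, and a 2x2 matmul stands in for torch.bmm on one batch element.

```python
# Pure-Python sketch of wrapping an unsupported op with dequant/quant:
# quantized inputs -> dequantize -> float op -> quantize again.
# Per-tensor affine quint8; every scale/zero_point below is hypothetical.

def quantize(m, scale, zero_point, qmin=0, qmax=255):
    return [[min(qmax, max(qmin, round(v / scale) + zero_point)) for v in row]
            for row in m]

def dequantize(qm, scale, zero_point):
    return [[(q - zero_point) * scale for q in row] for row in qm]

def matmul(a, b):
    # Float matrix product: stands in for torch.bmm on one batch element.
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

# Quantized tensors as produced by earlier quantized layers.
qx = quantize([[1.0, 2.0], [3.0, 4.0]], scale=0.125, zero_point=0)
qy = quantize([[0.5, 0.0], [0.0, 0.5]], scale=0.125, zero_point=0)

x = dequantize(qx, 0.125, 0)                   # role of DeQuantStub
y = dequantize(qy, 0.125, 0)                   # role of DeQuantStub
z = matmul(y, x)                               # runs in floating point
qz = quantize(z, scale=0.0625, zero_point=0)   # role of QuantStub
print(z)   # [[0.5, 1.0], [1.5, 2.0]]
print(qz)  # [[8, 16], [24, 32]]
```

The trade-off is an extra dequantize/quantize round trip around the op, with the op itself running at full precision.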

Yes, it’s quite hacky… indeed, it only works for dynamic quantization (quantizing only the weights)

I'm very interested in placing the QuantStub and DeQuantStub to avoid quantization. Could you please provide a piece of code?

Do you mean something like this:

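# Stubs are modules: create them in __init__, e.g.
#   self.dequant = torch.quantization.DeQuantStub()
#   self.quant_bmm = torch.quantization.QuantStub()  # own observer for the bmm output
# then call them in forward() on the previously created x and y:
x = self.dequant(x)
y = self.dequant(y)
z = torch.bmm(y, x)  # runs in floating point, so bmm is not quantized
z = self.quant_bmm(z)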

Thank you very much!

Yes, that's exactly what I meant. Did you try this?

Well, since quantization is not yet available for GPU inference, it is not worth it for me to try this out.