How to quantize register_parameter or nn.Parameter tensors?

I quantized a convolution model that has a state tensor.
The state tensor is intended to be used like a queue.

First, I tried to make a qint8 tensor and pass it to register_parameter.
But I got a type error when running the quantized model in PyTorch and libtorch.

I'm sorry that some of the code below is omitted; I couldn't copy the entire text for some reason.

class conv_with_state(nn.Module):
    def __init__(self, state_size, in_channel):
        super().__init__()
        scale, zero_point = 1e-4, 2
        dtype = torch.qint8
        float_state_tensor = torch.zeros(state_size, in_channel, 1)
        q_tensor = torch.quantize_per_tensor(float_state_tensor, scale, zero_point, dtype)
        self.shift_size = 2
        self.f_cat = nn.quantized.FloatFunctional()
        self.register_parameter('internal_state', nn.Parameter(q_tensor, requires_grad=False))
        self.conv = nn.Conv2d(--)

    def forward(self, x):
        self.internal_state[-self.shift_size:].data.copy_(self.internal_state[self.shift_size:])
        x = self.f_cat.cat([self.internal_state[-self.shift_size:].clone(), x])
        self.internal_state.data.copy_(x)
        return self.conv(x)

Traceback of TorchScript, original code (most recent call last):
  File "torch/nn/quantized/modules/functional_modules.py", line 174, in forward
    def cat(self, x, dim=0):
        r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim)
            ~~~~~~~~~~~~~~~~~ <--- HERE
        r = self.activation_post_process(r)
        return r
RuntimeError: All dtypes must be the same.
Aborted (core dumped)

So I tried to wrap the register_parameter tensor with QuantStub.
Similar to the code above, I got a type error.

class conv_with_state(nn.Module):
    def __init__(self, state_size, in_channel):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        float_state_tensor = torch.zeros(state_size, in_channel, 1)
        int_state_tensor = self.quant(float_state_tensor)
        self.shift_size = 2
        self.register_parameter('internal_state', nn.Parameter(int_state_tensor, requires_grad=False))
        self.conv = nn.Conv2d(--)

    def forward(self, x):
        self.internal_state.data.copy_(torch.roll(self.internal_state, -self.shift_size, 0))
        self.internal_state[-self.shift_size:].data.copy_(x)
        x = self.internal_state
        return self.conv(x)
Traceback (most recent call last):
        self.internal_state[-self.shift_size:].data.copy_(x)
RuntimeError: Copying from quantized Tensor to non-quantized Tensor is not allowed, please use dequantize to get a float Tensor from a quantized Tensor

How can I get quantized tensor parameters?


Hi @hwlee, could you post an e2e reproducible example which includes both your model and how you are calling the quantization APIs?

High level, the first error means that when you do cat, all the tensors have to be of the same dtype (otherwise cat is not defined), and the second error means that you are trying to copy a quantized tensor into a floating point tensor, which is also not defined. It's hard to recommend a fix without more context, i.e. an e2e example.
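Both failure modes can be reproduced in isolation with plain quantized tensors; a minimal sketch (the shapes and scale/zero_point values here are arbitrary, only for illustration):

import torch

# Two quantized tensors that differ only in dtype (qint8 vs quint8).
a = torch.quantize_per_tensor(torch.rand(1, 40, 4, 1), 0.1, 0, torch.qint8)
b = torch.quantize_per_tensor(torch.rand(1, 40, 2, 1), 0.1, 0, torch.quint8)

# Error 1: quantized cat requires all inputs to share the same dtype.
try:
    torch.ops.quantized.cat([a, b], scale=0.1, zero_point=0, dim=2)
except RuntimeError as e:
    print(e)  # All dtypes must be the same.

# Error 2: copying a quantized tensor into a float tensor is not defined.
float_buf = torch.zeros(1, 40, 2, 1)
try:
    float_buf.copy_(b)
except RuntimeError as e:
    print(e)  # Copying from quantized Tensor to non-quantized Tensor is not allowed ...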

Hi @Vasiliy_Kuznetsov, thank you for your reply.

I've added an e2e reproducible example below.

My first try

This code is based on the non-quantized version. I just ran the model on torch mobile, and it runs perfectly in the native app. But when quantization is applied to the module, it returns the error message below. That's a reasonable error.

import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

class conv_with_conv(nn.Module):

    def __init__(self,input_dim,in_ch,out_ch,kernel_size,stride):
        super().__init__()
        self.shift_size = input_dim
        self.register_buffer('internal_state',torch.zeros(kernel_size-stride+input_dim,in_ch))
        self.conv = nn.Conv2d(in_ch,out_ch,(kernel_size,1),stride)

    def forward(self,x):

        self.internal_state.data.copy_(torch.roll(self.internal_state,-self.shift_size))
        self.internal_state[-self.shift_size:].data.copy_(x.squeeze(-1).squeeze(0))
        x = self.internal_state.unsqueeze(0).transpose(1,2).unsqueeze(-1)
        x = self.conv(x)

        return x

class test_model(nn.Module):

    def __init__(self,):
        super().__init__()

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

        self.conv = conv_with_conv(2,40,1,3,1)

    def forward(self,x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        return x

model = test_model()
model.eval()
model.to('cpu')
dumy_input = torch.rand(1,2,40,1)
out = model(dumy_input)
print('float out',out)

model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare(model,inplace=True)
torch.quantization.convert(model,inplace=True)
jit_model = torch.jit.script(model)
torchscript_model_optimized = optimize_for_mobile(jit_model)
torchscript_model_optimized._save_for_lite_interpreter("test.ptl")

out = model(dumy_input)
print('quant out',out)

Error message of first code

float out tensor([[[[-0.0496],
          [-0.1337]]]], grad_fn=<ThnnConv2DBackward>)
/home/hwlee/.local/lib/python3.8/site-packages/torch/quantization/observer.py:121: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(
/home/hwlee/.local/lib/python3.8/site-packages/torch/quantization/observer.py:243: UserWarning: must run observer before calling calculate_qparams.                                        Returning default scale and zero point
  warnings.warn(
Traceback (most recent call last):
  File "conv_with_state.py", line 57, in <module>
    out = model(dumy_input)
  File "/home/hwlee/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "conv_with_state.py", line 35, in forward
    x = self.conv(x)
  File "/home/hwlee/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "conv_with_state.py", line 16, in forward
    self.internal_state[-self.shift_size:].data.copy_(x.squeeze(-1).squeeze(0))
RuntimeError: Copying from quantized Tensor to non-quantized Tensor is not allowed, please use dequantize to get a float Tensor from a quantized Tensor

Second

So I applied QuantStub to the register_buffer tensor. But I got the same type error about copying from a quantized tensor to a non-quantized tensor.

import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

class conv_with_conv(nn.Module):

    def __init__(self,input_dim,in_ch,out_ch,kernel_size,stride):
        super().__init__()
        self.shift_size = input_dim
        self.quant = torch.quantization.QuantStub()
        self.register_buffer('internal_state',self.quant(torch.zeros(kernel_size-stride+input_dim,in_ch)))
        self.conv = nn.Conv2d(in_ch,out_ch,(kernel_size,1),stride)

    def forward(self,x):

        self.internal_state.data.copy_(torch.roll(self.internal_state,-self.shift_size))
        self.internal_state[-self.shift_size:].data.copy_(x.squeeze(-1).squeeze(0))
        x = self.internal_state.unsqueeze(0).transpose(1,2).unsqueeze(-1)
        x = self.conv(x)

        return x

class test_model(nn.Module):

    def __init__(self,):
        super().__init__()

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

        self.conv = conv_with_conv(2,40,1,3,1)

    def forward(self,x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        return x

model = test_model()
model.eval()
model.to('cpu')
dumy_input = torch.rand(1,2,40,1)
out = model(dumy_input)
print('float out',out)

model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare(model,inplace=True)
torch.quantization.convert(model,inplace=True)
jit_model = torch.jit.script(model)
torchscript_model_optimized = optimize_for_mobile(jit_model)
torchscript_model_optimized._save_for_lite_interpreter("test.ptl")

out = model(dumy_input)
print('quant out',out)

Error message of second code

float out tensor([[[[-0.1726],
          [-0.2154]]]], grad_fn=<ThnnConv2DBackward>)
/home/hwlee/.local/lib/python3.8/site-packages/torch/quantization/observer.py:121: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(
/home/hwlee/.local/lib/python3.8/site-packages/torch/quantization/observer.py:243: UserWarning: must run observer before calling calculate_qparams.                                        Returning default scale and zero point
  warnings.warn(
Traceback (most recent call last):
  File "conv_with_state2.py", line 55, in <module>
    out = model(dumy_input)
  File "/home/hwlee/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "conv_with_state2.py", line 36, in forward
    x = self.conv(x)
  File "/home/hwlee/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "conv_with_state2.py", line 17, in forward
    self.internal_state[-self.shift_size:].data.copy_(x.squeeze(-1).squeeze(0))
RuntimeError: Copying from quantized Tensor to non-quantized Tensor is not allowed, please use dequantize to get a float Tensor from a quantized Tensor
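
One thing to note about this attempt: at __init__ time the QuantStub has not been prepared or converted yet, so it simply returns its float input unchanged, and the registered buffer is therefore still a float tensor. A quick standalone check (a sketch):

import torch

stub = torch.quantization.QuantStub()
t = stub(torch.zeros(4, 40))
print(t.dtype, t.is_quantized)  # torch.float32 False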

Last try

I made a qint8 tensor for the register_buffer tensor, and replaced torch.roll() with plain copy_() and clone(), because torch.roll() can't operate in the quantized model. Additionally, I didn't call forward() before quantization.convert(), since that would trigger the cat() type error between the qint8 tensor and a float tensor.
Despite these changes, the cat() function still returns "dtype isn't the same". I also tried to print the dtype of the tensor after the quantized convert, but nothing is printed (a standalone dtype check is shown after the error message below).

import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

class conv_with_conv(nn.Module):

    def __init__(self,input_dim,in_ch,out_ch,kernel_size,stride):
        super().__init__()
        self.shift_size = input_dim

        scale,zero_point = 1e-4,2
        dtype = torch.qint8
        float_state = torch.zeros(kernel_size-stride+input_dim,in_ch)
        int_state = torch.quantize_per_tensor(float_state,scale,zero_point,dtype)

        self.f_cat = nn.quantized.FloatFunctional()

        self.register_buffer('internal_state',int_state)
        self.conv = nn.Conv2d(in_ch,out_ch,(kernel_size,1),stride)

    def forward(self,x):

        self.internal_state[:self.shift_size].data.copy_(self.internal_state[self.shift_size:].clone())
        x = self.f_cat.cat((self.internal_state[:-self.shift_size].unsqueeze(0).unsqueeze(-1).clone(),x))

        x = self.internal_state.unsqueeze(0).transpose(1,2).unsqueeze(-1)
        x = self.conv(x)

        return x


class test_model(nn.Module):

    def __init__(self,):
        super().__init__()

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

        self.conv = conv_with_conv(2,40,1,3,1)

    def forward(self,x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        return x


model = test_model()
model.eval()
model.to('cpu')
dumy_input = torch.rand(1,2,40,1)

model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare(model,inplace=True)
torch.quantization.convert(model,inplace=True)
jit_model = torch.jit.script(model)
torchscript_model_optimized = optimize_for_mobile(jit_model)
torchscript_model_optimized._save_for_lite_interpreter("test.ptl")

out = model(dumy_input)
print('quant out',out)

Error message of last code

/home/hwlee/.local/lib/python3.8/site-packages/torch/quantization/observer.py:121: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(
/home/hwlee/.local/lib/python3.8/site-packages/torch/quantization/observer.py:243: UserWarning: must run observer before calling calculate_qparams.                                        Returning default scale and zero point
  warnings.warn(
Traceback (most recent call last):
  File "conv_with_state3.py", line 61, in <module>
    out = model(dumy_input)
  File "/home/hwlee/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "conv_with_state3.py", line 44, in forward
    x = self.conv(x)
  File "/home/hwlee/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "conv_with_state3.py", line 24, in forward
    x = self.f_cat.cat((self.internal_state[:-self.shift_size].unsqueeze(0).unsqueeze(-1).clone(),x))
  File "/home/hwlee/.local/lib/python3.8/site-packages/torch/nn/quantized/modules/functional_modules.py", line 212, in cat
    r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim)
RuntimeError: All dtypes must be the same.
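
As mentioned above, the mismatch can also be seen directly by checking the dtypes outside the scripted model; a standalone sketch (the quantization parameters are arbitrary):

import torch

# The state buffer as created manually in __init__ above: qint8
int_state = torch.quantize_per_tensor(torch.zeros(4, 40), 1e-4, 2, torch.qint8)

# An activation as the converted QuantStub produces it with the default qconfig: quint8
x = torch.quantize_per_tensor(torch.rand(1, 2, 40, 1), 0.1, 0, torch.quint8)

print(int_state.dtype)  # torch.qint8
print(x.dtype)          # torch.quint8 -> the two cat inputs do not match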

Please let me know if I'm doing something wrong or if there are other approaches I could try.


Here is a code snippet of your example modified with print statements of the dtypes of inputs to torch.cat: gist:a9dbaf4d2e69d8289aa7d21d6d7e4e3c · GitHub

Looks like currently you are trying to concatenate a tensor with dtype torch.qint8 with a tensor with dtype torch.quint8. If you make those dtypes match, it will work. Hope this helps.
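
For example, creating the state buffer with dtype=torch.quint8 (the dtype the default activation observers use) instead of torch.qint8 should make the cat inputs match; a minimal standalone sketch of just that change:

import torch

scale, zero_point = 1e-4, 2

# State buffer quantized as quint8 so it matches the dtype coming out of the QuantStub
int_state = torch.quantize_per_tensor(torch.zeros(1, 40, 2, 1), scale, zero_point, torch.quint8)

# A quantized activation, also quint8
x = torch.quantize_per_tensor(torch.rand(1, 40, 2, 1), 0.1, 0, torch.quint8)

# With matching dtypes, quantized cat no longer raises "All dtypes must be the same."
out = torch.ops.quantized.cat([int_state, x], scale=0.1, zero_point=0, dim=2)
print(out.dtype, out.shape)  # torch.quint8 torch.Size([1, 40, 4, 1])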