Quantization Error During Concat -- RuntimeError: Didn't find kernel to dispatch to for operator 'aten::_cat'

During static quantization of my model, I encounter the following error -

RuntimeError: Didn't find kernel to dispatch to for operator 'aten::_cat'. Tried to look up kernel for dispatch key 'QuantizedCPUTensorId'. Registered dispatch keys are: [CPUTensorId, VariableTensorId]

I have fused and quantized the model, as well as the input image. But it throws an error on concat raised by - y = torch.cat([sources[0], sources[1]], dim=1)

Any suggestions would be appreciated. :slight_smile:
Full code here -

Please see the usage of skip_add (+= operation) here: https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html#model-architecture

The operators listed here https://pytorch.org/docs/stable/quantization.html#torch.nn.quantized.QFunctional should be replaced with their functional module counterpart in the network before post-training quantization.
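A minimal sketch of that replacement, assuming a hypothetical two-branch module (the names TwoBranch, branch_a and branch_b are made up for illustration and are not from the poster's code). The torch.cat call is routed through a FloatFunctional instance so that an observer can be attached to it during prepare():

```python
import torch
import torch.nn as nn


class TwoBranch(nn.Module):
    """Hypothetical module whose two branches are concatenated on the channel axis."""

    def __init__(self):
        super().__init__()
        self.branch_a = nn.Conv2d(3, 4, kernel_size=1)
        self.branch_b = nn.Conv2d(3, 4, kernel_size=1)
        # Functional wrapper that stands in for the bare torch.cat call.
        self.cat = nn.quantized.FloatFunctional()

    def forward(self, x):
        sources = [self.branch_a(x), self.branch_b(x)]
        # Instead of: y = torch.cat([sources[0], sources[1]], dim=1)
        return self.cat.cat([sources[0], sources[1]], dim=1)


model = TwoBranch().eval()
with torch.no_grad():
    y = model(torch.randn(1, 3, 8, 8))
print(y.shape)
```

In float mode this behaves like a plain torch.cat. The sketch stops before quantization; QuantStub/DeQuantStub, prepare() and convert() would still be needed around it.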

@dskhudia thank you for your suggestion. :slight_smile:

I replaced the 'cat' calls with nn.quantized.FloatFunctional().cat()

But I run into another error -
TypeError: NumPy conversion for Variable[QuantizedCPUQUInt8Type] is not supported

from the line y[0,:,:,0].cpu().data.numpy()
(line 52 in the link above)

How do I convert this quantized variable to numpy ? Thanks again for your help.

There are a couple of options depending on what you want:

If you want quantized integer data, use int_repr (https://github.com/pytorch/pytorch/wiki/Introducing-Quantized-Tensor)

If you want float data, dequantize and use your existing way of converting it to numpy.
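To make the two options concrete, here is a small sketch on a hand-made quantized tensor. The scale and zero_point values are arbitrary and chosen only for illustration; in a converted model they come from the observers.

```python
import torch

x = torch.tensor([0.0, 0.5, 1.0, 1.5])
scale, zero_point = 0.1, 10  # illustrative values, not taken from any real model
q = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)

# Option 1: the raw integer representation (uint8 for quint8 tensors).
ints = q.int_repr().numpy()

# Option 2: float values, recovered as (int - zero_point) * scale.
floats = q.dequantize().numpy()

print(ints.dtype, ints)
print(floats.dtype, floats)
```

Note that the integers from int_repr() are only meaningful together with the scale and zero point, so feeding them directly into float post-processing will generally give wrong results.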

@dskhudia
Thank you very much, I changed
score_link = y[0,:,:,1].cpu().data.numpy() to score_link = y[0,:,:,1].int_repr().cpu().data.numpy() as per your suggestion. But the prediction is very bad.

Can you point me to how to dequantize the model ?

This is the final prediction. Correct? If yes, you would need to dequantize the final tensor, e.g., using dequantized_y = y.dequantize()

thanks a lot @dskhudia
I tried both methods. Final prediction is really bad :frowning:

Original Prediction from FP32 model -
image

Prediction from INT8 model -
image

Not sure where I’m going wrong :confused:

You may want to try some quantization accuracy improvement techniques, such as:

- Per-channel quantization for weights
- Quantization-aware training
- Measuring torch.norm of the difference between the float model's and the quantized model's outputs to see where it's off the most

is there an example for per channel quantization and measuring the torch norm between the 2 models ?

For per channel see https://github.com/pytorch/tutorials/blob/master/advanced_source/static_quantization_tutorial.py

and for norm you can use something like the following:

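# ref_output: outputs of the float model; qtz_output: outputs of the quantized model
SQNR = []
for i in range(len(ref_output)):
    signal = torch.norm(ref_output[i][0])
    noise = torch.norm(ref_output[i][0] - qtz_output[i][0])
    SQNR.append(20 * torch.log10(signal / noise).numpy())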

print('SQNR (dB)', SQNR)
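The same formula as a self-contained helper, checked on a toy pair of tensors. The function name sqnr_db and the example values are only for illustration; higher values mean the two outputs agree more closely.

```python
import torch


def sqnr_db(ref, test):
    """Signal-to-quantization-noise ratio in dB: 20 * log10(||ref|| / ||ref - test||)."""
    noise = torch.norm(ref - test)
    return (20 * torch.log10(torch.norm(ref) / noise)).item()


# Toy example: ||ref|| = 5 and ||ref - test|| = 0.5, so the ratio is 10.
ref = torch.tensor([3.0, 4.0])
test = ref + torch.tensor([0.3, 0.4])
value = sqnr_db(ref, test)
print('SQNR (dB)', value)
```

If the second tensor is still quantized, it would need a .dequantize() call before being passed in.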

@dskhudia
The performance improved slightly after per channel quantization, but it is still very bad

image

Do you think I should try float 16 instead? If so, how do I change the config to change it to Float16.

Also, in your earlier response ref_output is the output from the net ? i.e ref_output=net(x), is that what you meant ?

Thanks again for your help , hope I can resolve this problem

Float16 quantized operators do not exist for static quantization. Since current CPUs do not support float16 compute natively, converting to float16 in compute-bound cases doesn't provide much performance benefit.

ref_output is from the float model. You might want to check the norm at few different places in the network to see where we are deviating too much from floating point results.
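One possible way to look at several places in the network is to capture intermediate outputs with forward hooks and compare them layer by layer. This is only a sketch: capture_outputs is a hypothetical helper, the tiny Sequential model is made up, and a copy with slightly perturbed weights stands in for the quantized model so that the example runs anywhere.

```python
import copy

import torch
import torch.nn as nn


def capture_outputs(model, x, layer_names):
    """Run model(x) and return {layer name: float output tensor}."""
    captured = {}
    modules = dict(model.named_modules())
    handles = []
    for name in layer_names:
        def hook(module, inputs, output, name=name):
            # Quantized outputs are dequantized so norms can be taken on floats.
            out = output.dequantize() if output.is_quantized else output
            captured[name] = out.detach()
        handles.append(modules[name].register_forward_hook(hook))
    with torch.no_grad():
        model(x)
    for handle in handles:
        handle.remove()
    return captured


torch.manual_seed(0)
float_model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 4, 3, padding=1)
).eval()

# Stand-in for the quantized model: a copy with slightly perturbed weights.
other_model = copy.deepcopy(float_model)
with torch.no_grad():
    for p in other_model.parameters():
        p.add_(0.01 * torch.randn_like(p))

x = torch.randn(1, 3, 16, 16)
names = ["0", "2"]  # names of the two conv layers inside the Sequential
ref = capture_outputs(float_model, x, names)
new = capture_outputs(other_model, x, names)

rel_err = {n: (torch.norm(ref[n] - new[n]) / torch.norm(ref[n])).item() for n in names}
print(rel_err)
```

With a real quantized model, the layer names would have to be chosen so that they exist in both models, since fusion can change the module hierarchy.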

In PyTorch there’s a way to compare the module level quantization error, which could help to debug and narrow down the issue. I’m working on an example and will post here later.


@Raghav_Gurbaxani, have you tried using histogram observer for activation? In most cases this could improve the accuracy of the quantized model. You can do:
model.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.default_histogram_observer,
    weight=torch.quantization.default_per_channel_weight_observer)
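For context, this is roughly how that qconfig would fit into the post-training static quantization flow. It is a sketch on a made-up one-layer model (TinyNet is hypothetical), calibrated on random tensors instead of real images, and it assumes a CPU backend that supports per-channel quantized convolution (for example the default x86 engine).

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical one-conv model, used only to show the flow."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))


torch.manual_seed(0)
model = TinyNet().eval()
model.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.default_histogram_observer,
    weight=torch.quantization.default_per_channel_weight_observer)

# prepare() returns a copy with observers attached.
prepared = torch.quantization.prepare(model)

# Calibration: random data here; real calibration should use representative inputs.
with torch.no_grad():
    for _ in range(4):
        prepared(torch.randn(1, 3, 16, 16))

quantized = torch.quantization.convert(prepared)
with torch.no_grad():
    out = quantized(torch.randn(1, 3, 16, 16))
print(type(quantized.conv), out.shape, out.dtype)
```

The histogram observer only helps if the calibration data resembles the real inputs, so it is worth calibrating on the same preprocessed images used for evaluation.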

thanks @hx89 , if you could post that example for compare module level quantization error - It would be great :smiley:

In the meantime, I tried the histogram observer and the result is still pretty bad :confused:
image

any other suggestions ?

Have you checked the accuracy of fused_model? By checking the accuracy of fused_model before converting to int8 model we can know if the issue is in the preprocessing part or in the quantized model.

If fused_model has good accuracy, the next step is to check the quantization error of the weights. Could you try the following code:

import numpy as np
import torch


def l2_error(ref_tensor, new_tensor):
    """Compute the l2 error between two tensors.

    Args:
        ref_tensor (numpy array): Reference tensor.
        new_tensor (numpy array): New tensor to compare with.

    Returns:
        abs_error: l2 error
        relative_error: relative l2 error
    """
    assert (
        ref_tensor.shape == new_tensor.shape
    ), "The shape between two tensors is different"

    diff = new_tensor - ref_tensor
    abs_error = np.linalg.norm(diff)
    ref_norm = np.linalg.norm(ref_tensor)
    if ref_norm == 0:
        if np.allclose(ref_tensor, new_tensor):
            relative_error = 0
        else:
            relative_error = np.inf
    else:
        relative_error = np.linalg.norm(diff) / ref_norm
    return abs_error, relative_error

float_model_dbg = fused_model
qmodel_dbg = quantized

for key in float_model_dbg.state_dict().keys():
    float_w = float_model_dbg.state_dict()[key]
    qkey = key
    
    # Get rid of the extra hierarchy level of the fused Conv in the float model,
    # e.g. 'features.0.0.weight' -> 'features.0.weight'
    if key.endswith('.weight'):
        qkey = key[:-9] + key[-7:]

    if qkey in qmodel_dbg.state_dict():
        q_w = qmodel_dbg.state_dict()[qkey]
        if q_w.dtype == torch.float:
            abs_error, relative_error = l2_error(float_w.numpy(), q_w.detach().numpy())
        else:
            abs_error, relative_error = l2_error(float_w.numpy(), q_w.dequantize().numpy())
        print(key, ', abs error = ', abs_error, ", relative error = ", relative_error)

It should print out the quantization error for each Conv weight such as:

features.0.0.weight , abs error =  0.21341866 , relative error =  0.01703797
features.3.squeeze.0.weight , abs error =  0.095942035 , relative error =  0.012483358
features.3.expand1x1.0.weight , abs error =  0.071949296 , relative error =  0.010309489
features.3.expand3x3.0.weight , abs error =  0.18284422 , relative error =  0.025256516
features.4.squeeze.0.weight , abs error =  0.088713735 , relative error =  0.011313644
features.4.expand1x1.0.weight , abs error =  0.0780085 , relative error =  0.0126931975
...

@hx89 the performance of the fused model is good

image

That means there’s something wrong on the quantization side, not the fusion side. :confused:

Here’s the log of the relative norm errors -
https://github.com/raghavgurbaxani/experiments/blob/master/quantization_error.txt

Can you suggest what to do next? Is there any way to reduce these errors? Apart from QAT, of course.

It looks like the first Conv, basenet.slice1.3.0.weight, has the largest error. Could you try skipping the quantization of that Conv and keeping it as a float module? We have previously seen CV models whose first Conv is sensitive to quantization, and skipping it gives better accuracy.

@hx89 actually it seems like all these have pretty high relative errors -
[basenet.slice1.7.0.weight, basenet.slice1.10.0.weight, basenet.slice2.14.0.weight, basenet.slice2.17.0.weight, basenet.slice3.20.0.weight, basenet.slice3.24.0.weight, basenet.slice3.27.0.weight, basenet.slice4.30.0.weight, basenet.slice4.34.0.weight]

although that seems like a good idea, keeping a few layers as float while converting the rest to int8.

I am not sure how to pass a partial model to torch.quantization.convert() for quantization and then combine the partially quantized model with the unquantized layers for inference on the image.

Could you provide an example ? Thanks a ton

It's actually simpler than that. To skip the first conv, for example, there are two steps:

Step 1: Move the quant stub after the first conv in the forward function of the module.

For example in the original quantizable module, quant stub is at the beginning before conv1:

class QuantizableNet(nn.Module):
    def __init__(self):
        ...
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.fc(x)
        x = self.dequant(x)
        return x

To skip the quantization of conv1 we can move self.quant() after conv1:

class QuantizableNet(nn.Module):
    def __init__(self):
        ...
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.conv1(x)
        x = self.quant(x)
        x = self.maxpool(x)
        x = self.fc(x)
        x = self.dequant(x)
        return x

Step 2: Then we need to set the qconfig of conv1 to None after prepare(), this way PyTorch knows we want to keep conv1 as float module and won’t swap it with quantized module:

model = QuantizableNet()
...
torch.quantization.prepare(model, inplace=True)  # inplace=True so that `model` itself is prepared
model.conv1.qconfig = None
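Putting the two steps together, a runnable sketch could look like the following. The two-conv model PartiallyQuantizedNet is hypothetical, the "fbgemm" default qconfig is an assumption about the target backend, and calibration uses random tensors purely to make the example self-contained.

```python
import torch
import torch.nn as nn


class PartiallyQuantizedNet(nn.Module):
    """Hypothetical model: conv1 stays float, conv2 is quantized."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.quant = torch.quantization.QuantStub()
        self.conv2 = nn.Conv2d(8, 4, 3, padding=1)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.conv1(x)   # float compute
        x = self.quant(x)   # quantization starts after conv1 (step 1)
        x = self.conv2(x)
        return self.dequant(x)


torch.manual_seed(0)
model = PartiallyQuantizedNet().eval()
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")

torch.quantization.prepare(model, inplace=True)
model.conv1.qconfig = None  # step 2: convert() will leave conv1 as a float module

# Calibration with random data; real calibration should use representative inputs.
with torch.no_grad():
    for _ in range(4):
        model(torch.randn(1, 3, 16, 16))

quantized = torch.quantization.convert(model)
with torch.no_grad():
    out = quantized(torch.randn(1, 3, 16, 16))
print(type(quantized.conv1), type(quantized.conv2), out.shape)
```

The same pattern should extend to skipping several sensitive layers, as long as the quant stub sits after the last float layer and each skipped module has its qconfig set to None.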