I found that in some cases the PyTorch result of a QAT (quantization-aware training) linear layer is wrong. I will describe the problem below.
I inserted some prints into the model's forward method:

```python
print(layers_3_out)              # quantized input to the linear layer
layers_4_out = layers_4(layers_3_out)
print(layers_4._packed_params)   # packed (weight, bias) of the quantized linear
print(layers_4_out)              # quantized output
```
`layers_4` is defined in `__init__`:

```python
self.layers_4 = nn.Linear(in_features=16, out_features=1, bias=True)
```
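(After `torch.quantization.convert`, this `nn.Linear` is swapped for its quantized counterpart, which is why `_packed_params` is available. A quick way to confirm the swap, assuming the standard eager-mode module mapping and that the layer is reachable as `net.layers_4`:)

```python
# The converted module should be a quantized Linear, not nn.Linear
print(type(net.layers_4))  # e.g. torch.nn.quantized.modules.linear.Linear
```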
As shown above, I print the input, the linear layer's parameters, and the output. The results are as follows:
```
tensor([[5.0954, 1.2068, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.6091,
         0.0000, 1.3409, 2.6818, 4.2908, 2.1454, 2.9500, 1.8772],
        [5.0954, 1.3409, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.6818, 4.2908, 2.0113, 2.9500, 1.8772],
        [5.0954, 1.3409, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.6818, 4.2908, 2.0113, 2.9500, 1.8772],
        [5.0954, 1.3409, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.6818, 4.2908, 2.0113, 2.9500, 1.8772],
        [5.0954, 1.3409, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.6818, 4.2908, 2.0113, 2.9500, 1.8772],
        [5.0954, 1.3409, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.6818, 4.2908, 2.0113, 2.9500, 1.8772],
        [5.2295, 1.4750, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.3409,
         0.0000, 1.4750, 3.0840, 4.4249, 2.0113, 2.9500, 1.7432],
        [5.0954, 1.3409, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.6818, 4.2908, 2.0113, 2.9500, 1.8772],
        [4.9613, 1.3409, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.6818, 4.2908, 2.0113, 2.9500, 1.7432],
        [4.9613, 1.3409, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.6818, 4.2908, 2.0113, 2.9500, 1.7432],
        [4.9613, 1.3409, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.6818, 4.2908, 2.0113, 2.9500, 1.7432],
        [4.6931, 1.2068, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.4136, 3.8886, 1.8772, 2.8159, 1.7432],
        [4.6931, 1.2068, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.4136, 3.8886, 1.8772, 2.8159, 1.7432],
        [4.6931, 1.2068, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.4136, 3.8886, 1.8772, 2.8159, 1.7432],
        [4.6931, 1.2068, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.4136, 3.8886, 1.8772, 2.8159, 1.7432],
        [4.4249, 1.0727, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.2068, 2.2795, 3.6204, 1.8772, 2.8159, 1.7432],
        [4.4249, 1.0727, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.2068, 2.1454, 3.7545, 1.7432, 2.8159, 1.7432],
        [4.5590, 1.2068, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.2068, 2.2795, 3.7545, 1.7432, 2.8159, 1.7432],
        [4.4249, 1.0727, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.2068, 2.1454, 3.7545, 1.7432, 2.8159, 1.7432],
        [4.4249, 1.0727, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.2068, 2.1454, 3.7545, 1.7432, 2.8159, 1.7432],
        [4.4249, 1.0727, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.2068, 2.1454, 3.7545, 1.7432, 2.8159, 1.7432],
        [4.6931, 1.2068, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.4136, 4.0227, 1.7432, 2.8159, 1.7432],
        [4.4249, 1.3409, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.2068,
         0.0000, 1.3409, 2.4136, 3.8886, 1.7432, 2.6818, 1.7432],
        [4.5590, 1.3409, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.2068,
         0.0000, 1.3409, 2.4136, 3.8886, 1.7432, 2.8159, 1.6091],
        [4.4249, 1.0727, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.2068, 2.1454, 3.7545, 1.7432, 2.8159, 1.7432],
        [4.6931, 1.2068, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.4136, 4.0227, 1.7432, 2.8159, 1.7432],
        [4.6931, 1.2068, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.4136, 3.8886, 1.8772, 2.8159, 1.7432],
        [4.9613, 1.3409, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.6818, 4.2908, 2.0113, 2.9500, 1.7432],
        [5.0954, 1.3409, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.6818, 4.2908, 2.0113, 2.9500, 1.8772],
        [5.0954, 1.3409, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.6818, 4.2908, 2.0113, 2.9500, 1.8772],
        [5.3636, 1.4750, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.4750, 3.0840, 4.5590, 2.1454, 2.9500, 1.8772],
        [5.0954, 1.3409, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4750,
         0.0000, 1.3409, 2.6818, 4.2908, 2.0113, 2.9500, 1.8772]],
       size=(32, 16), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.13408887386322021,
       zero_point=173)
(tensor([[-0.5172, -0.6572, 0.2640, 0.4471, 0.2855, -0.5549, 0.4741, 0.4364,
          -0.5118, -0.0162, -0.3448, -0.5172, -0.6896, -0.3448, -0.6465, -0.4902]],
        size=(1, 16), dtype=torch.qint8,
        quantization_scheme=torch.per_tensor_affine, scale=0.005387257784605026,
        zero_point=0), tensor([-0.1708], requires_grad=True))
tensor([[3.7054],
        [3.8143],
        [3.8143],
        [3.8143],
        [3.8143],
        [3.8143],
        [3.5964],
        [3.8143],
        [3.8143],
        [3.8143],
        [3.8143],
        [3.9233],
        [3.9233],
        [3.9233],
        [3.9233],
        [4.0323],
        [4.1413],
        [4.0323],
        [4.1413],
        [4.1413],
        [4.1413],
        [3.9233],
        [4.1413],
        [4.1413],
        [4.1413],
        [3.9233],
        [3.9233],
        [3.8143],
        [3.8143],
        [3.8143],
        [3.5964],
        [3.8143]], size=(32, 1), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.10898102819919586,
       zero_point=186)
```
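For reference, a quantized tensor prints its dequantized values; each stored integer `q` maps to `(q - zero_point) * scale`. A quick arithmetic check of the first output entry (the printed 3.7054 together with the scale and zero_point above implies a stored integer of 220):

```python
# Dequantize the first output entry by hand
scale, zero_point = 0.10898102819919586, 186
q = 220                           # stored integer implied by the printed 3.7054
print((q - zero_point) * scale)   # ≈ 3.7054, matches the printed output
```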
Then I calculated the first row's result of this linear layer by hand (the last cell of the `mul` row is the sum of all sixteen products plus the bias):
input | 5.0954 | 1.2068 | 0 | 0 | 0 | 0 | 0 | 0 | 1.6091 | 0 | 1.3409 | 2.6818 | 4.2908 | 2.1454 | 2.95 | 1.8772 | bias
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
weight | -0.5172 | -0.6572 | 0.264 | 0.4471 | 0.2855 | -0.5549 | 0.4741 | 0.4364 | -0.5118 | -0.0162 | -0.3448 | -0.5172 | -0.6896 | -0.3448 | -0.6465 | -0.4902 | -0.1708
mul | -2.63534088 | -0.79310896 | 0 | 0 | 0 | 0 | 0 | 0 | -0.82353738 | 0 | -0.46234232 | -1.38702696 | -2.95893568 | -0.73973392 | -1.907175 | -0.92020344 | -12.79820454
So the result should be -12.79820454, but PyTorch returns 3.7054, which is wrong.
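To rule out an arithmetic mistake on my side, here is a minimal re-check of the same dot product, with the input row, weight, and bias copied from the printouts above:

```python
import torch

# First input row, weight row, and bias as printed above (dequantized values)
x = torch.tensor([5.0954, 1.2068, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                  1.6091, 0.0, 1.3409, 2.6818, 4.2908, 2.1454, 2.9500, 1.8772])
w = torch.tensor([-0.5172, -0.6572, 0.2640, 0.4471, 0.2855, -0.5549, 0.4741, 0.4364,
                  -0.5118, -0.0162, -0.3448, -0.5172, -0.6896, -0.3448, -0.6465, -0.4902])
b = -0.1708

print(torch.dot(x, w) + b)  # ≈ -12.7982, not 3.7054
```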
Before inference, I prepare and convert the model for QAT as follows:

```python
net.train()
net.qconfig = torch.quantization.get_default_qat_qconfig("qnnpack")
net.fuse_modules()                                 # the model's own fusion method
torch.quantization.prepare_qat(net, inplace=True)  # insert fake-quant / observer modules
net.load_state_dict(new_state_dict, strict=True)   # load the QAT-trained weights
net.eval()
net.apply(torch.quantization.disable_observer)
net = torch.quantization.convert(net)              # convert to a real quantized model
```
In this model, all the other linear layers are computed correctly; only this one is wrong. Has anyone met this problem before? Any help would be appreciated.