Hello, I’ve been experimenting with TorchScript and dynamic quantization, and I often run into the issue that the outputs of dynamically quantized models are not consistent between Python and Java.
To reproduce the issue I created a fork of the python java-demo: GitHub - westphal-jan/java-demo.
To set up, you need to download libtorch and set its location in build.gradle (line 16).
Download Link: https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.1%2Bcpu.zip
I created a simple dummy model with one linear layer and exported it both unquantized and quantized here: create_dummy_models.py
(The code can be run using the dependencies defined in requirements.txt, but I also committed the dummy models.)
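For reference, a minimal sketch of what such an export script might look like, assuming a single-linear-layer model quantized with `torch.quantization.quantize_dynamic` and saved as TorchScript (the class name, layer sizes, and file names here are illustrative, not necessarily what create_dummy_models.py uses):

```python
import torch
import torch.nn as nn

class Dummy(nn.Module):
    """Toy model: a single linear layer."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

model = Dummy().eval()

# Unquantized TorchScript export
torch.jit.save(torch.jit.script(model), "dummy.pt")

# Dynamic quantization: linear weights are quantized to int8 ahead of
# time, activations are quantized on the fly at inference time
qmodel = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)
torch.jit.save(torch.jit.script(qmodel), "dummy_quantized.pt")
```

Both saved modules can then be loaded from Java via `org.pytorch.Module.load`, which is where the discrepancy below shows up.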
Python:
Unquantized model:
[[-2.758167028427124, 2.0038578510284424, -4.114053726196289, -1.2928203344345093, 1.4940322637557983]]
Quantized model:
[[-2.747678756713867, 1.9912285804748535, -4.110795021057129, -1.2891944646835327, 1.4982664585113525]]
You can run the Java code with ./gradlew run.
Java:
Unquantized model:
data: [-2.758167, 2.0038579, -4.1140537, -1.2928203, 1.4940323]
[W qlinear_dynamic.cpp:239] Warning: Currently, qnnpack incorrectly ignores reduce_range when it is set to true; this may change in a future release. (function apply_dynamic_impl)
Quantized model:
data: [-2.7473624, 1.9966378, -4.110954, -1.283469, 1.4918814]
As you can see, the output of the unquantized model is perfectly consistent, while the output of the dynamically quantized model is slightly inconsistent. It might seem insignificant, but with larger models like a transformer it becomes more pronounced (differences often already in the first decimal place). Am I misunderstanding something conceptually?
I thought that since the code is compiled down to C++ and both examples run on the same architecture (CPU, x86_64), it should produce the same output even with dynamic quantization (the activations are quantized on the fly, but that should still be deterministic).
Note: I made sure that Python and Java use the same Torch version, 1.13.1, which is the latest version published on Maven (mvnrepository → org.pytorch/pytorch_java_only).