Hi,
I am following the example in rewrite.py, which creates “reference quantized representations” of the corresponding floating-point ops. From what I understand, it tries to match “dq → float_op → q” and replace it with some reference quantized op.
What I would like to do is provide my own C++ implementation for the following function in rewrite.py:
```python
def _reference_quantized_linear(
    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
    out_scale, out_zero_point, out_quant_min, out_quant_max
):
```
The question I am grappling with is how to set up the signature of the function. There seem to be two choices:
Use `torch::Tensor` for all arguments (including scalar arguments like scales and zero points). For example:

```cpp
m.def("fc_relu_s8u8s8(Tensor x_i8, Tensor x_scale, Tensor xzp, Tensor xmin, Tensor xmax, Tensor w_i8, Tensor w_scale, Tensor wzp, Tensor wmin, Tensor wmax, Tensor bfp32, Tensor y_scale, Tensor yzp, Tensor ymin, Tensor ymax) -> Tensor");
```
If I use the above representation, I can replace the pattern `_qdq_quantized_linear` with `fc_relu_s8u8s8` in the graph, but when I run the graph it fails, because the scalar values (like `x_scale`, `x_zp`, etc.) cannot be cast to Tensors when calling `fc_relu_s8u8s8`.
Or, I could use scalar arguments for the scales and zero points, as below:

```cpp
m.def("fc_prelu_s8u8s8_raw(Tensor x_i8, float xs, int xzp, int xmin, int xmax, Tensor w_i8, float ws, int wzp, int wmin, int wmax, Tensor bfp32, float ys, int yzp, int ymin, int ymax) -> Tensor");
```
If I use this approach, the subgraph rewriter complains about “dead code” (which I guess is related to the scalar arguments).
Is there a way around this, or is there some reference code available that shows how to do such a thing? I am stuck between a rock and a hard place.
I would appreciate your help or insights on this. Please let me know if I can provide more information. I am using a recent torch build from GitHub:
```python
In [5]: torch.version.__version__
Out[5]: '2.2.0a0+gitb4ce501'
```
Thanks,
Vijay.