Adding custom quantized C++ op

Hi,

I am following the example in rewrite.py, which creates “reference quantized representations” of the corresponding floating-point ops.

From what I understand, it seems like it is trying to match “dq → float_op → q” and replace it with some reference quantized op.

What I would like to do is to provide my own C++ implementation for the following function in rewrite.py.

def _reference_quantized_linear(
    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
    out_scale, out_zero_point, out_quant_min, out_quant_max
):
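
For context, the computation such a reference representation encodes (and that my C++ op has to reproduce) is roughly the following integer arithmetic plus a requantization step. This is only a simplified sketch of the math, not the actual body in rewrite.py:

    import torch

    def _reference_quantized_linear_sketch(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
        bias_fp32,
        out_scale, out_zero_point, out_quant_min, out_quant_max,
    ):
        # Accumulate (x - x_zp) @ (w - w_zp)^T; the real op keeps this in int32.
        acc = torch.nn.functional.linear(
            (x_i8.to(torch.int32) - x_zero_point).to(torch.float32),
            (weight_i8.to(torch.int32) - weight_zero_point).to(torch.float32),
        )
        # Fold the fp32 bias into the accumulator, in units of x_scale * weight_scale.
        acc = acc + torch.round(bias_fp32 / (x_scale * weight_scale))
        # Requantize to the output scale/zero point and clamp to the int8 range.
        out = torch.round(acc * (x_scale * weight_scale / out_scale)) + out_zero_point
        return torch.clamp(out, out_quant_min, out_quant_max).to(torch.int8)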

The question I am grappling with is how to set up the signature of the function. There seem to be two choices:

Use torch::Tensor for all arguments (including scalar arguments like scales and zero points). For example:

    m.def("fc_relu_s8u8s8(Tensor x_i8, Tensor x_scale, Tensor xzp, Tensor xmin, Tensor xmax, Tensor w_i8, Tensor w_scale, Tensor wzp, Tensor wmin, Tensor wmax, Tensor bfp32, Tensor y_scale, Tensor yzp, Tensor ymin, Tensor ymax) -> Tensor");

If I use the above representation, I can replace the _qdq_quantized_linear pattern with fc_relu_s8u8s8 in the graph, but when I run the rewritten graph it fails, because the scalar values (x_scale, x_zp, etc.) cannot be cast to Tensors when calling fc_relu_s8u8s8.
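
To make the failure concrete, here is a minimal stand-alone repro of the same mismatch, using a throwaway Python-registered op with an all-Tensor schema (the "demo" namespace and op are made up, not my real op):

    import torch

    lib = torch.library.Library("demo", "DEF")
    lib.define("tensor_only(Tensor x, Tensor scale) -> Tensor")
    lib.impl("tensor_only", lambda x, scale: x * scale, "CPU")

    x = torch.randn(4)
    print(torch.ops.demo.tensor_only(x, torch.tensor(0.5)))  # OK: scale passed as a Tensor

    try:
        torch.ops.demo.tensor_only(x, 0.5)  # bare Python float where a Tensor is expected
    except (TypeError, RuntimeError) as e:
        print("fails like the rewritten graph does:", e)

The subgraph rewriter drops the scalar literals from the matched pattern straight into the new call, which is presumably why the all-Tensor schema fails the same way as the second call above.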

Or, I could use scalar arguments for scales and zero points as below:

    m.def("fc_prelu_s8u8s8_raw(Tensor x_i8, float xs, int xzp, int xmin, int xmax, Tensor w_i8, double ws, int wzp, int wmin,int wmax, Tensor bfp32, doublw ys, int yzp, int ymin, int ymax) -> Tensor");

If I use this approach, the subgraph rewriter complains about “dead code” (which I guess is related to the scalar arguments).

Is there a way around this, or is there some reference code available that shows how to do such a thing? I am stuck between a rock and a hard place.

I would appreciate your help or insights on this. Please let me know if I can provide more information. I am using a recent torch build from GitHub.

In [5]: torch.version.__version__
Out[5]: '2.2.0a0+gitb4ce501'

Thanks,
Vijay.

OK. Found the problem. I defined my op using the actual types for each parameter:

m.def("fc_prelu_s8u8s8_raw(Tensor x_i8, float xs, int xzp, int xmin, int xmax, Tensor w_i8, double ws, int wzp, int wmin,int wmax, Tensor bfp32, double ys, int yzp, int ymin, int ymax) -> Tensor");

Then, I followed the instructions in “The C++ Custom Operators Manual” (Google Docs) to bind the op to torch (as torch.ops.<my_lib>.fc_prelu_s8u8s8). This doc should really be part of the official PyTorch docs!
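
Once the shared library is built, the op becomes addressable under torch.ops after loading it; a quick sanity check looks something like this (the library path and namespace are placeholders for my actual ones):

    import torch

    # Placeholder path/namespace: substitute the shared library produced by the
    # build and the namespace used in the C++ TORCH_LIBRARY registration.
    torch.ops.load_library("build/libmy_quant_ops.so")

    # The op is now a torch.ops callable that carries the schema defined above.
    print(torch.ops.my_lib.fc_prelu_s8u8s8_raw)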

Then, instead of get_aten_graph_module (which internally calls capture_pre_autograd_graph, which I am guessing is roughly equivalent to torch.export.export in that it exports an FX graph by running real inputs through the module), I used torch.fx.symbolic_trace to convert the wrapper function around torch.ops.<my_lib>.fc_prelu_s8u8s8 into an FX graph. With this approach I don’t see any warnings about literals or dead code, and I am able to rewrite the dq → linear → q combination with my new fixed-point op.

        # replacement = get_aten_graph_module(replacement, example_inputs)  # type: ignore[arg-type, assignment]
        replacement = torch.fx.symbolic_trace(replacement)
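
For completeness, the wrapper function that I trace looks roughly like this (the namespace and function name below are placeholders; the traced GraphModule is what the subgraph rewriter receives as the replacement):

    import torch

    # Mirrors the argument order of _reference_quantized_linear so the rewriter
    # can wire the matched pattern's inputs straight into the custom op.
    def _fc_prelu_s8u8s8_replacement(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
        bias_fp32,
        out_scale, out_zero_point, out_quant_min, out_quant_max,
    ):
        return torch.ops.my_lib.fc_prelu_s8u8s8_raw(
            x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
            weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
            bias_fp32,
            out_scale, out_zero_point, out_quant_min, out_quant_max,
        )

    # symbolic_trace keeps the scalar arguments as graph placeholders instead of
    # baking them in as constants, so the rewriter no longer reports dead code.
    replacement_gm = torch.fx.symbolic_trace(_fc_prelu_s8u8s8_replacement)
    # torch.fx.subgraph_rewriter.replace_pattern(model_gm, pattern_gm, replacement_gm)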