Unable to run custom quantized module with custom Tensor class in convert_fx/convert_to_reference_fx

There are two problems when I try to run PyTorch CUDA int8 inference with custom int8 layers:

  1. convert_fx doesn’t provide any customization for the nni → nniq conversion (which is defined in STATIC_LOWER_FUSED_MODULE_MAP in _lower_to_native_backend.py). I need to modify this global value to convert custom fusion layers.
  2. Quantized modules only support torch.Tensor; my custom Tensor class doesn’t work. The conversion generates the following code:
GraphModule(
  (net): Module(
    (0): Module(
      (0): QuantizedSparseConvReLU(1, 32, kernel_size=[3, 3], stride=[1, 1], scale=0.03292962536215782, zero_point=0, padding=[1, 1], dilation=[1, 1], output_padding=[0, 0], wqscheme=torch.per_channel_affine)
    )
    (1): Module(
      (0): QuantizedSparseConvReLU(32, 64, kernel_size=[3, 3], stride=[1, 1], scale=0.037994351238012314, zero_point=0, padding=[1, 1], dilation=[1, 1], output_padding=[0, 0], wqscheme=torch.per_channel_affine)
    )
    (2): Module(
      (0): QuantizedSparseConvReLU(64, 64, kernel_size=[2, 2], stride=[2, 2], scale=0.038743481040000916, zero_point=0, padding=[0, 0], dilation=[1, 1], output_padding=[0, 0], wqscheme=torch.per_channel_affine)
    )
    (3): Module(
      (0): QuantizedSparseConvReLU(64, 64, kernel_size=[2, 2], stride=[2, 2], scale=0.05028770491480827, zero_point=0, padding=[0, 0], dilation=[1, 1], output_padding=[0, 0], wqscheme=torch.per_channel_affine)
    )
    (4): Module(
      (0): QuantizedSparseConvReLU(64, 64, kernel_size=[3, 3], stride=[2, 2], scale=0.05744420737028122, zero_point=0, padding=[1, 1], dilation=[1, 1], output_padding=[0, 0], wqscheme=torch.per_channel_affine)
    )
    (5): QuantizedSparseConv(Reference)(64, 10, kernel_size=[4, 4], stride=[4, 4], padding=[0, 0], dilation=[1, 1], output_padding=[0, 0], algo=ConvAlgo.MaskImplicitGemm)
  )
)

def forward(self, features : torch.Tensor, indices : torch.Tensor, batch_size : int):
    sparse_conv_tensor = spconv_pytorch_core_SparseConvTensor(features, indices, [28, 28], batch_size);  features = indices = batch_size = None
    _scale_0 = self._scale_0
    _zero_point_0 = self._zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(sparse_conv_tensor, _scale_0, _zero_point_0, torch.qint8);  sparse_conv_tensor = _scale_0 = _zero_point_0 = None
    net_0_0 = getattr(getattr(self.net, "0"), "0")(quantize_per_tensor);  quantize_per_tensor = None
    net_1_0 = getattr(getattr(self.net, "1"), "0")(net_0_0);  net_0_0 = None
    net_2_0 = getattr(getattr(self.net, "2"), "0")(net_1_0);  net_1_0 = None
    net_3_0 = getattr(getattr(self.net, "3"), "0")(net_2_0);  net_2_0 = None
    net_4_0 = getattr(getattr(self.net, "4"), "0")(net_3_0);  net_3_0 = None
    dequantize_5 = net_4_0.dequantize();  net_4_0 = None
    net_5 = getattr(self.net, "5")(dequantize_5);  dequantize_5 = None
    net_5_scale_0 = self.net_5_scale_0
    net_5_zero_point_0 = self.net_5_zero_point_0
    quantize_per_tensor_6 = torch.quantize_per_tensor(net_5, net_5_scale_0, net_5_zero_point_0, torch.qint8);  net_5 = net_5_scale_0 = net_5_zero_point_0 = None
    dequantize_6 = quantize_per_tensor_6.dequantize();  quantize_per_tensor_6 = None
    dense = dequantize_6.dense();  dequantize_6 = None
    net_6_scale_0 = self.net_6_scale_0
    net_6_zero_point_0 = self.net_6_zero_point_0
    quantize_per_tensor_7 = torch.quantize_per_tensor(dense, net_6_scale_0, net_6_zero_point_0, torch.qint8);  dense = net_6_scale_0 = net_6_zero_point_0 = None
    flatten = torch.flatten(quantize_per_tensor_7, 1);  quantize_per_tensor_7 = None
    dequantize_8 = flatten.dequantize();  flatten = None
    log_softmax = torch.nn.functional.log_softmax(dequantize_8, dim = 1, _stacklevel = 3, dtype = None);  dequantize_8 = None
    return log_softmax

torch.quantize_per_tensor doesn’t accept my SparseConvTensor; the correct code for this custom class should be:

quantize_per_tensor = sparse_conv_tensor.replace_feature(torch.quantize_per_tensor(sparse_conv_tensor.features, _scale_0, _zero_point_0, torch.qint8))

How can I solve these problems?

Edit:
Problem 2 is solved by a simple FX graph transform:

def transform_qdq(m: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in m.graph.nodes:
        # Only rewrite call_function nodes (e.g. torch.add);
        # node.target is the function that call_function invokes.
        if node.op == 'call_function':
            if node.target == torch.quantize_per_tensor:
                node.target = custom_quantize_per_tensor

    m.graph.lint()  # sanity-check that the graph is still well-formed
    m.recompile()
    return m
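custom_quantize_per_tensor is not shown above; a minimal sketch of what it could look like (assuming the spconv.pytorch.core.SparseConvTensor class that appears in the generated forward, and falling back to the stock op for plain tensors):

import torch
from spconv.pytorch.core import SparseConvTensor  # assumed spconv import path

def custom_quantize_per_tensor(x, scale, zero_point, dtype):
    # Quantize only the .features tensor of a SparseConvTensor and keep the
    # sparse structure via replace_feature; plain tensors go through unchanged.
    if isinstance(x, SparseConvTensor):
        return x.replace_feature(
            torch.quantize_per_tensor(x.features, scale, zero_point, dtype)
        )
    return torch.quantize_per_tensor(x, scale, zero_point, dtype)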

cc @jerryzh168: requesting that the mapping used by the lowering code be exposed

For problem 1, please use convert_to_reference_fx if you are doing customizations; feel free to copy-paste our lowering code as a starting point. convert_fx means converting for the native PyTorch backend.
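For reference, a minimal sketch of that flow (MyModel and the example input are placeholders; exact signatures vary a bit between recent PyTorch releases):

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_to_reference_fx

model = MyModel().eval()                       # placeholder float model
example_inputs = (torch.randn(1, 3, 28, 28),)  # placeholder calibration input

qconfig_mapping = get_default_qconfig_mapping("fbgemm")
prepared = prepare_fx(model, qconfig_mapping, example_inputs)
prepared(*example_inputs)                      # calibration
reference = convert_to_reference_fx(prepared)  # reference modules + explicit Q/DQ
# a copy of _lower_to_native_backend (or custom lowering) would then run on `reference`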

I’m working on TensorRT int8 deployment; all int8 layers are implemented as TensorRT plugins.
convert_to_reference_fx only converts quantized modules to reference modules; it doesn’t:

  1. perform QDQ merging
  2. convert modules to a custom backend (CUDA in my case) so that the int8 model can be tested in PyTorch

TensorRT won’t fuse QDQ for any int8 plugins, so if I use convert_to_reference_fx instead, I have to reimplement all of the QDQ removal code from _lower_to_native_backend. In addition, it’s important to be able to test the int8 model in PyTorch rather than directly in TensorRT, for debugging purposes.

In my deployment work, I need convert_fx to do QDQ merging for all custom int8 plugins, but I must not merge QDQ for regular layers such as Linear and Conv, so that TensorRT explicit quantization works correctly. So I:

  1. add a custom layer lowering dict to STATIC_LOWER_FUSED_MODULE_MAP
  2. remove Linear from STATIC_LOWER_FUSED_MODULE_MAP

to achieve this.
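A hedged sketch of what these two patches could look like (SparseConvReLU, ReferenceSparseConv, and QuantizedSparseConvReLU are placeholders for the custom modules; the map is a private implementation detail, so its location and entry format should be checked against the installed PyTorch version, mirroring the existing nni.LinearReLU entry):

import torch.nn.intrinsic as nni
from torch.ao.quantization.fx import _lower_to_native_backend as _lower

# 1. register the custom fused module, following the existing entry format:
#    fused module -> (reference module class, quantized fused module class)
_lower.STATIC_LOWER_FUSED_MODULE_MAP[SparseConvReLU] = (
    ReferenceSparseConv,      # reference module used during convert
    QuantizedSparseConvReLU,  # backend (CUDA plugin) module
)

# 2. drop Linear so its Q/DQ pairs survive for TensorRT explicit quantization
_lower.STATIC_LOWER_FUSED_MODULE_MAP.pop(nni.LinearReLU, None)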

Edit:
The most important thing for a custom backend is to expose the “qconfig_map” used in _lower_to_native_backend to users. I can implement all of the backend code myself if “qconfig_map” is available.

Sounds good. Would copy-pasting _lower_to_native_backend work for you? I don’t think we want to expose this function as a public API, since it’s unclear how useful that would be: lowering for different backends tends to be different. See the possible lowering options in rfcs/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md at master · pytorch/rfcs · GitHub.

Currently, if I want to implement a new backend, I need not only _lower_to_native_backend but also qconfig_map and node_name_to_scope (the latter is only used for packed weights, which TensorRT doesn’t need), both of which are generated by private functions in torch.ao.quantization.fx.convert. The custom backend RFC doesn’t seem to contain any APIs for qconfig_map or node_name_to_scope.

Quick Workaround

You can get both of these from the prepared model:

prepared_model = prepare_fx(model, ...)
qconfig_map = prepared_model._qconfig_map
node_name_to_scope = prepared_model._node_name_to_scope

although these are private.

Another probably more robust solution

I think a more robust solution would be:
(1) qconfig_map is not very important; it is just used to skip lowering for some patterns. Maybe you can simply remove this argument and all related code after you copy _lower_to_native_backend.
(2) node_name_to_scope: a more robust way to get it would be to trace the original model again with QuantizationTracer (pytorch/quantize_fx.py at master · pytorch/pytorch · GitHub) and get node_name_to_scope from there, e.g.

tracer = QuantizationTracer(...)
tracer.trace(original_model)
node_name_to_scope = tracer.node_name_to_scope

This might be slightly more robust than the previous solution, but we may change it later as well, since these are all implementation details of the quantization flow.