I ran into two problems while trying to run CUDA int8 inference in PyTorch with custom int8 layers:

- `convert_fx` doesn't provide any way to customize the conversion from `nni` (fused intrinsic) modules to `nniq` (quantized intrinsic) modules, which is defined in `STATIC_LOWER_FUSED_MODULE_MAP` in `_lower_to_native_backend.py`. To convert my custom fused layers, I have to modify this global mapping directly.
- Quantized modules only support `torch.Tensor`, so my custom tensor class doesn't work. The converted quantized module generates the following code:

```
GraphModule(
(net): Module(
(0): Module(
(0): QuantizedSparseConvReLU(1, 32, kernel_size=[3, 3], stride=[1, 1], scale=0.03292962536215782, zero_point=0, padding=[1, 1], dilation=[1, 1], output_padding=[0, 0], wqscheme=torch.per_channel_affine)
)
(1): Module(
(0): QuantizedSparseConvReLU(32, 64, kernel_size=[3, 3], stride=[1, 1], scale=0.037994351238012314, zero_point=0, padding=[1, 1], dilation=[1, 1], output_padding=[0, 0], wqscheme=torch.per_channel_affine)
)
(2): Module(
(0): QuantizedSparseConvReLU(64, 64, kernel_size=[2, 2], stride=[2, 2], scale=0.038743481040000916, zero_point=0, padding=[0, 0], dilation=[1, 1], output_padding=[0, 0], wqscheme=torch.per_channel_affine)
)
(3): Module(
(0): QuantizedSparseConvReLU(64, 64, kernel_size=[2, 2], stride=[2, 2], scale=0.05028770491480827, zero_point=0, padding=[0, 0], dilation=[1, 1], output_padding=[0, 0], wqscheme=torch.per_channel_affine)
)
(4): Module(
(0): QuantizedSparseConvReLU(64, 64, kernel_size=[3, 3], stride=[2, 2], scale=0.05744420737028122, zero_point=0, padding=[1, 1], dilation=[1, 1], output_padding=[0, 0], wqscheme=torch.per_channel_affine)
)
(5): QuantizedSparseConv(Reference)(64, 10, kernel_size=[4, 4], stride=[4, 4], padding=[0, 0], dilation=[1, 1], output_padding=[0, 0], algo=ConvAlgo.MaskImplicitGemm)
)
)
def forward(self, features : torch.Tensor, indices : torch.Tensor, batch_size : int):
sparse_conv_tensor = spconv_pytorch_core_SparseConvTensor(features, indices, [28, 28], batch_size); features = indices = batch_size = None
_scale_0 = self._scale_0
_zero_point_0 = self._zero_point_0
quantize_per_tensor = torch.quantize_per_tensor(sparse_conv_tensor, _scale_0, _zero_point_0, torch.qint8); sparse_conv_tensor = _scale_0 = _zero_point_0 = None
net_0_0 = getattr(getattr(self.net, "0"), "0")(quantize_per_tensor); quantize_per_tensor = None
net_1_0 = getattr(getattr(self.net, "1"), "0")(net_0_0); net_0_0 = None
net_2_0 = getattr(getattr(self.net, "2"), "0")(net_1_0); net_1_0 = None
net_3_0 = getattr(getattr(self.net, "3"), "0")(net_2_0); net_2_0 = None
net_4_0 = getattr(getattr(self.net, "4"), "0")(net_3_0); net_3_0 = None
dequantize_5 = net_4_0.dequantize(); net_4_0 = None
net_5 = getattr(self.net, "5")(dequantize_5); dequantize_5 = None
net_5_scale_0 = self.net_5_scale_0
net_5_zero_point_0 = self.net_5_zero_point_0
quantize_per_tensor_6 = torch.quantize_per_tensor(net_5, net_5_scale_0, net_5_zero_point_0, torch.qint8); net_5 = net_5_scale_0 = net_5_zero_point_0 = None
dequantize_6 = quantize_per_tensor_6.dequantize(); quantize_per_tensor_6 = None
dense = dequantize_6.dense(); dequantize_6 = None
net_6_scale_0 = self.net_6_scale_0
net_6_zero_point_0 = self.net_6_zero_point_0
quantize_per_tensor_7 = torch.quantize_per_tensor(dense, net_6_scale_0, net_6_zero_point_0, torch.qint8); dense = net_6_scale_0 = net_6_zero_point_0 = None
flatten = torch.flatten(quantize_per_tensor_7, 1); quantize_per_tensor_7 = None
dequantize_8 = flatten.dequantize(); flatten = None
log_softmax = torch.nn.functional.log_softmax(dequantize_8, dim = 1, _stacklevel = 3, dtype = None); dequantize_8 = None
return log_softmax
```

`torch.quantize_per_tensor` doesn't accept my `SparseConvTensor`. For this custom class, the correct code should quantize only the tensor's features and put them back into the `SparseConvTensor`:

```
quantize_per_tensor = sparse_conv_tensor.replace_feature(torch.quantize_per_tensor(sparse_conv_tensor.features, _scale_0, _zero_point_0, torch.qint8))
```

How can I solve these problems?

Edit:

Problem 2 can be solved with a simple FX graph transform that replaces every `torch.quantize_per_tensor` call with a custom function:

```
def transform_qdq(m: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in m.graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
if node.target == torch.quantize_per_tensor:
node.target = custom_quantize_per_tensor
m.graph.lint() # Does some checks to make sure the
# Graph is well-formed.
m.recompile()
return m
```