Error when exporting a quantized model to ONNX

I am trying to export a quantized model to ONNX and get the following error message:

RuntimeError: false INTERNAL ASSERT FAILED at "…/aten/src/ATen/quantized/Quantizer.cpp":439, please report a bug to PyTorch. cannot call qscheme on UnknownQuantizer

Originally my model was prepared in FX graph mode, but I get the same error with this dummy model. The PyTorch version is 1.13.0.dev20220818+cu102.
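
For context, the FX graph mode preparation was along these lines (a minimal sketch with the default QAT qconfig; float_model is a placeholder for my real model, not the exact pipeline):

from torch.ao.quantization import get_default_qat_qconfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

# prepare the float model for QAT with the default fbgemm qconfig
qconfig_dict = {"": get_default_qat_qconfig("fbgemm")}
example_inputs = (torch.randn(1, 1, 16, 16),)
prepared = prepare_qat_fx(float_model.train(), qconfig_dict, example_inputs)
# ... run the QAT training loop on `prepared` ...
quantized = convert_fx(prepared.cpu().eval())

The eager mode dummy reproducer follows: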

import numpy as np
import torch

class QuantizedDummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.conv = torch.nn.Conv2d(1, 16, (3, 3))
        self.prelu = torch.nn.PReLU(16)
        # give the PReLU distinct per-channel weights
        self.prelu.weight.data = torch.tensor(np.arange(16), dtype=torch.float32)

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.prelu(x)
        x = self.dequant(x)
        return x

dummy = QuantizedDummyModel()
dummy.train()
dummy.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(dummy, inplace=True)

criterion = torch.nn.MSELoss().cuda()
optimizer = torch.optim.Adam(dummy.parameters(),lr=1e-3)
dummy.cuda()
dummy.train()
for j in range(5):
    tensor = torch.randn(16,1,16,16).cuda()
    dummy_gt = torch.randn(16,16,14,14).cuda()
    output = dummy(tensor)
    loss = criterion(output, dummy_gt)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
qdummy = torch.quantization.convert(dummy.cpu().eval(), inplace=False)
inp = torch.randn(1,1,16,16)
torch.onnx.export(qdummy,
                  inp,
                  'foo.onnx',
                  input_names=['input'],
                  dynamic_axes={'input': {0: "batch"}},
                  opset_version=13,
                  verbose=False)

Could you please help me?

Why do you have UnknownQuantizer? Can you print the model?

Here is the fake-quantized model (after prepare_qat):

QuantizedDummyModel(
  (quant): QuantStub(
    (activation_post_process): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (dequant): DeQuantStub()
  (conv): Conv2d(
    1, 16, kernel_size=(3, 3), stride=(1, 1)
    (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_channel_symmetric, reduce_range=False
      (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
    )
    (activation_post_process): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (prelu): PReLU(
    num_parameters=16
    (activation_post_process): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
)

And here is the converted quantized model:

QuantizedDummyModel(
  (quant): Quantize(scale=tensor([0.0640]), zero_point=tensor([61]), dtype=torch.quint8)
  (dequant): DeQuantize()
  (conv): QuantizedConv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), scale=0.04176247864961624, zero_point=60)
  (prelu): QuantizedPReLU()
)

Also, if I change PReLU to ReLU, the problem goes away.
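
If PReLU has to stay, one possible workaround (a sketch only, assuming the QuantizedPReLU module is what trips the assert; I have not verified the export end to end) is to exclude PReLU from quantization and run it in float between extra stubs:

import torch

class DummyModelFloatPReLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 16, (3, 3))
        self.dequant_prelu = torch.quantization.DeQuantStub()  # back to float before PReLU
        self.prelu = torch.nn.PReLU(16)
        self.quant_prelu = torch.quantization.QuantStub()      # re-quantize after PReLU
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant_prelu(x)
        x = self.prelu(x)
        x = self.quant_prelu(x)
        x = self.dequant(x)
        return x

model = DummyModelFloatPReLU()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model.prelu.qconfig = None  # PReLU stays a float module through convert()
torch.quantization.prepare_qat(model.train(), inplace=True)
# ... QAT training and torch.quantization.convert() as above ...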