MaxPool2d doesn't work in TorchScript after quantization?

I saved the quantized model to TorchScript successfully, but when I run the TorchScript model I get the following error:

Traceback (most recent call last):
  File "", line 293, in  <module>
    js_out = ts(x)
  File "/home/dai/py36env/lib/python3.6/site-packages/torch/nn/modules/", line 541, in __call__
    result = self.forward(*input, **kwargs)
RuntimeError: Didn't find kernel to dispatch to for operator 'aten::max_pool2d_with_indices'. Tried to look up kernel for dispatch key 'QuantizedCPUTensorId'. Registered dispatch keys are: [CUDATensorId, CPUTensorId, VariableTensorId]
The above operation failed in interpreter, with the following stack trace:
at <string>:63:30

            return torch.avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override), backward

        def max_pool2d(self,
                       kernel_size: List[int],
                       stride: List[int],
                       padding: List[int],
                       dilation: List[int],
                       ceil_mode: bool):
            output, indices = torch.max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode)
                              ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            def backward(grad_output):
                grad_self = torch.max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices)
                return grad_self, None, None, None, None, None
            return output, backward

        def max_pool2d_with_indices(self,
                                    kernel_size: List[int],
                                    stride: List[int],
                                    padding: List[int],

The above operation failed in interpreter, with the following stack trace:

Here’s my code for quantization:


    img = Image.open("3.png").convert("L").resize((300,32))
    x = ToTensor()(img).unsqueeze(0)
    out = crnn(x)
    crnn.cnn.qconfig = torch.quantization.get_default_qconfig('fbgemm')

    #  calibration

    # quantize
    torch.quantization.convert(crnn.cnn, inplace=True)

    quantized_out = crnn(x)

    ts = torch.jit.script(crnn)
    ts_out = ts(x)

The error occurs when the TorchScript model ts is called.

And here’s my code for crnn.cnn:

class ConvBNReLU(nn.Sequential):
    def __init__(self,i,in_channels,out_channels,kernel_size=3,stride = 1,padding = 1,groups = 1,bn = False):
        self.bn = bn
        if bn:
                ("conv{0}".format(i),nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,groups = groups,bias = True)),
                ("conv{0}".format(i),nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,groups = groups,bias = True)),

class CNN(nn.Module):

    def __init__(self, imgH, nc):
        super(CNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        self.cnn = nn.Sequential()
        pn = 0
        for i in range(7):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]

            if i in [0,1,3,5]:
                if i in [0,1]:
                pn += 1 

        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.cnn(x)
        x = self.dequant(x)
        return x

My PyTorch version is 1.3.1+cpu.

How can I solve this problem? Looking forward to your reply.

The input tensor of the CRNN is cropped from an image tensor. When I convert that tensor to a PIL image and back again, the code works:
img_out = ToTensor()(ToPILImage()(img_out.squeeze(0))).unsqueeze(0)

and the following code also helps:

img_out = torch.ones(img_out_.shape)
img_out.data = img_out_.clone()

Could anyone tell me what the difference is between a cropped tensor and a tensor converted from an image?
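One way to start narrowing this down is to print the properties that can differ between the two tensors. The sketch below assumes the difference lies in autograd state or memory layout, which the thread does not confirm. The names `describe`, `full`, `cropped` and `rebuilt` are hypothetical, and `full` is random data standing in for the real image tensor.

```python
import torch


def describe(t):
    """Collect the tensor properties most likely to differ between the two inputs."""
    return {
        "dtype": t.dtype,
        "shape": tuple(t.shape),
        "requires_grad": t.requires_grad,
        "has_grad_fn": t.grad_fn is not None,
        "is_contiguous": t.is_contiguous(),
    }


# Hypothetical stand-in for a larger image tensor that is tracked by autograd
# (for example, the output of another network).
full = torch.rand(1, 1, 64, 600, requires_grad=True)

# Cropping by slicing returns a view: it keeps the autograd history and the
# strides of the larger tensor.
cropped = full[:, :, 16:48, 100:400]

# A detached, contiguous copy, similar in effect to a PIL round trip.
rebuilt = cropped.detach().clone(memory_format=torch.contiguous_format)

print("cropped:", describe(cropped))
print("rebuilt:", describe(rebuilt))
```

If `requires_grad` differs between the two, the scripted model may be taking the autograd-aware path for the cropped input, which would be consistent with the `max_pool2d_with_indices` call in the traceback.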

Hi @dalalaa, that sounds like a bug. Do you mind raising an issue on GitHub?

Thank you for your reply. I will try to put together a minimal example that reproduces it.

I know this is late, but for others who land here with the same error: wrapping the ts(x) call in a torch.no_grad() block solved it for me.
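A minimal sketch of that calling pattern follows. `TinyCNN` is a hypothetical float stand-in for crnn.cnn, since the original model and its quantization steps are not reproduced here. It shows only how to wrap the scripted call, not the quantized failure itself.

```python
import torch
import torch.nn as nn


class TinyCNN(nn.Module):
    """Hypothetical stand-in for crnn.cnn: one conv layer followed by max pooling."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        return self.pool(self.conv(x))


model = TinyCNN().eval()
ts = torch.jit.script(model)

# Same input size as in the question: a 32x300 grayscale image.
x = torch.rand(1, 1, 32, 300)

# Inference only: with gradient tracking disabled, the scripted call does not
# need the autograd-aware pooling path.
with torch.no_grad():
    ts_out = ts(x)

print(ts_out.shape, ts_out.requires_grad)
```

Calling `.detach()` on the input before passing it in is another option to try if the input carries autograd history, though the thread only reports the `torch.no_grad()` variant as working.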