NotImplementedError: Could not run 'aten::_slow_conv2d_forward' with arguments from the 'QuantizedCPU' backend

I was trying to quantize my TSN model (ResNet backbone + classification head). I mainly used the model structure from here: mmaction2/README.md at master · open-mmlab/mmaction2 · GitHub

I used eager mode post-training static quantization. After convert(), I got this error when I tried to run inference:

NotImplementedError: Could not run 'aten::_slow_conv2d_forward' with arguments from the 'QuantizedCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_slow_conv2d_forward' is only available for these backends: [CPU, CUDA, BackendSelect, Python, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradLazy, AutogradXPU, AutogradMLC, AutogradHPU, AutogradNestedTensor, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, AutocastCPU, Autocast, Batched, VmapMode, Functionalize].

The quantized model looks like this:

Recognizer2D(
  (quant): Quantize(scale=tensor([0.0079]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
  (backbone): ResNet(
    (conv1): ConvModule(
      (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)    
      (activate): ReLU(inplace=True)
    )
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
        (downsample): ConvModule(
          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
    )
    (layer2): Sequential(
      (0): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
        (downsample): ConvModule(
          (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (bn): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (3): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
    )
    (layer3): Sequential(
      (0): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
        (downsample): ConvModule(
          (conv): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (bn): QuantizedBatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (3): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (4): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (5): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
    )
    (layer4): Sequential(
      (0): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
        (downsample): ConvModule(
          (conv): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (bn): QuantizedBatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): QuantizedBatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
    )
  )
  (cls_head): TSNHead(
    (loss_cls): CrossEntropyLoss()
    (consensus): AvgConsensus()
    (avg_pool): AdaptiveAvgPool2d(output_size=(1, 1))
    (dropout): QuantizedDropout(p=0.4, inplace=False)
    (fc_cls): QuantizedLinear(in_features=2048, out_features=101, scale=0.05660918354988098, zero_point=47, qscheme=torch.per_channel_affine)
  )
)

Could someone help me, please? Thanks!

Hi Nanton,

It would appear that you are passing a quantized tensor into a kernel expecting floating point values. Usually the way to fix this is to add a torch.quantization.DeQuantStub before the offending op. This link provides an example of how to fix it. Please let me know if this fixes your problem.
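A minimal sketch of that kind of fix, under some assumptions: SmallNet, its layer sizes, and the choice to keep the pooling and linear head in float are made up for illustration and are not the TSN model from this thread. The conv runs quantized, and a DeQuantStub converts the tensor back to float before the ops that were left unquantized.

```python
import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub


class SmallNet(nn.Module):
    """Hypothetical toy model, not the TSN model from this thread."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # float -> quantized after convert()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()  # quantized -> float after convert()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.conv(x))
        # Dequantize before any op that is kept in float.
        x = self.dequant(x)
        x = self.pool(x).flatten(1)
        return self.fc(x)


def quantize_static(model, calibration_batches):
    """Eager-mode post-training static quantization (sketch)."""
    model.eval()
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    # Assumption for this example: the head stays in float.
    model.pool.qconfig = None
    model.fc.qconfig = None
    prepared = torch.quantization.prepare(model)
    with torch.no_grad():
        for batch in calibration_batches:
            prepared(batch)  # observers record activation ranges
    return torch.quantization.convert(prepared)
```

Calling quantize_static(SmallNet(), [torch.randn(2, 3, 8, 8)]) should return a model whose conv is quantized while fc is still a float nn.Linear. Without the dequant call in forward, fc would receive a quantized tensor and fail in the same way as the error above.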

Best,
-Andrew

Hi Andrew,

Thanks for your help. I found there are a few possible reasons for this error. It can be caused by an operation that the quantized backend does not support, as you mentioned. It can also be caused by differences between PyTorch versions. In my case, I downgraded PyTorch to 1.7 and the problem went away.

Best,
Nanton

Can you tell me about your model architecture and your torch and torchvision versions? I ran into the same problem.

This happens when you call the forward function of a float Conv2d module with a quantized tensor (the output of a quant stub) after quantizing the model.

If your model is structured like

init:
self.conv = nn.Conv2d(...)
self.relu = nn.ReLU()
self.batch_norm = nn.BatchNorm2d(...)
self.quant_stub = torch.quantization.QuantStub()
self.dequant_stub = torch.quantization.DeQuantStub()

forward:
x = self.quant_stub(x)
x = self.conv(x)
x = self.relu(x)
x = self.batch_norm(x)
x = self.dequant_stub(x)

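and you apply quantization in a way that converts the quant_stub but not the Conv2d module, you'll see this error. In the OP's printout the Conv2d modules are still float, while the BatchNorm2d modules (and the Dropout and Linear in the head) were converted, so it looks like the quantized output of the quant stub reaches a float conv. That would explain the error.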
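One way to spot this without reading the whole printout is to walk the converted model and list the layers that are still plain float modules. This is only a sketch: find_float_modules and FLOAT_TYPES are names I made up, and the exact type check deliberately ignores subclasses, so custom layers that subclass nn.Conv2d would need to be added by hand.

```python
import torch.nn as nn

# Float layer types that convert() would normally swap for quantized versions.
FLOAT_TYPES = (nn.Conv2d, nn.Linear, nn.BatchNorm2d)


def find_float_modules(model):
    """Return (name, class name) for every submodule that is still float.

    Uses an exact type check, because some quantized module classes
    inherit from their float counterparts in certain PyTorch versions.
    """
    return [
        (name, type(module).__name__)
        for name, module in model.named_modules()
        if type(module) in FLOAT_TYPES
    ]
```

Running find_float_modules(quantized_model) right after convert() and seeing conv layers in the result, on a path that receives the output of the quant stub, points to this problem.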

If you want to quantize only the batchnorm operation, you would have to move the quant stubs so that they surround just the batchnorm.
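Roughly, that rearrangement could look like the sketch below. The class name, layer sizes, and the quantize_batchnorm_only helper are hypothetical, and I am assuming the conv is opted out by setting its qconfig to None before prepare().

```python
import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub


class ConvReluBN(nn.Module):
    """Hypothetical toy module; layer sizes are made up for illustration."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.quant_stub = QuantStub()
        self.batch_norm = nn.BatchNorm2d(8)
        self.dequant_stub = DeQuantStub()

    def forward(self, x):
        x = self.relu(self.conv(x))  # stays in float
        x = self.quant_stub(x)       # quantize right before the batchnorm
        x = self.batch_norm(x)       # runs quantized after convert()
        return self.dequant_stub(x)  # back to float right after it


def quantize_batchnorm_only(model, calibration_batches):
    """Eager-mode static quantization that leaves the conv in float."""
    model.eval()
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    model.conv.qconfig = None  # opt the conv out so convert() skips it
    prepared = torch.quantization.prepare(model)
    with torch.no_grad():
        for batch in calibration_batches:
            prepared(batch)  # calibration pass
    return torch.quantization.convert(prepared)
```

With the stubs placed this way, the float conv never sees a quantized tensor, and the quantized batchnorm never sees a float one.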