Could not run 'aten::thnn_conv2d_forward' with arguments from the 'QuantizedCPU' backend

I ran into this error when trying to run a quantized `tf_efficientnet_b3_ns` model from timm. It boils down to this snippet:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Conv2d):
    """Minimal stand-in for timm's ``conv2d_same`` layer.

    An ``nn.Conv2d`` subclass (3 in-channels, 1 out-channel, 3x3 kernel) whose
    ``forward`` routes through an overridden ``_conv_forward`` that calls
    ``F.conv2d`` directly with default stride/padding/dilation/groups.

    NOTE(review): per the report, swapping this class for a plain
    ``nn.Conv2d(3, 1, 3)`` makes quantization work. Presumably eager-mode
    ``convert`` does not replace this subclass with a quantized conv, so the
    float ``F.conv2d`` call receives a quantized tensor. Confirm against the
    installed PyTorch version.
    """

    def __init__(self):
        super().__init__(in_channels=3, out_channels=1, kernel_size=3)

    def _conv_forward(self, x, weight, bias):
        # Deliberately ignores self.stride/self.padding/etc.; they are all at
        # their defaults for this module, which F.conv2d's defaults match.
        return F.conv2d(x, weight, bias=bias)

    def forward(self, x):
        weight, bias = self.weight, self.bias
        return self._conv_forward(x, weight, bias)


class QuantizedModel(nn.Module):
    """Wrap a float model between a QuantStub and a DeQuantStub.

    The stubs mark where tensors enter and leave the region that eager-mode
    quantization (``prepare``/``convert``) operates on. Per the stack trace in
    the report, after ``convert`` the tensor handed to ``model_fp32`` is on the
    QuantizedCPU backend.

    Args:
        model_fp32: the float module to run between the two stubs.
    """

    def __init__(self, model_fp32):
        super().__init__()
        # Registration order (quant, dequant, model_fp32) is kept as-is so that
        # named_children()/state_dict() ordering is unchanged.
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.model_fp32 = model_fp32

    def forward(self, x):
        return self.dequant(self.model_fp32(self.quant(x)))

model_fp32 = Model()
model_q = QuantizedModel(model_fp32)
# Static post-training quantization should be prepared and calibrated in eval
# mode. This has no effect on this conv-only repro, but the real timm model
# has BatchNorm and dropout (see the stack trace).
model_q.eval()
model_q.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model_q, inplace=True)
# Calibration pass (after prepare, before convert); gradients are not needed.
with torch.no_grad():
    model_q(torch.zeros(size=(1, 3, 64, 64)))
model_q = torch.quantization.convert(model_q, inplace=True)
# Per the report, this call fails with a RuntimeError about running a conv op
# on the 'QuantizedCPU' backend. With model_fp32 = nn.Conv2d(3, 1, 3) it runs.
# NOTE(review): presumably convert() does not swap the Model subclass for a
# quantized conv, so Model's float F.conv2d receives a quantized tensor.
model_q(torch.zeros(size=(1, 3, 64, 64)))

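The key piece carried over from the timm library is a module that inherits from `nn.Conv2d` and overrides its forward pass to call `F.conv2d` directly (as timm's `conv2d_same` layer does). Here’s the stack trace from running the full timm model: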

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-14-894b9494d2e2> in <module>
      1 model_q.eval()
----> 2 q_latency = measure_inference_latency(model_q, input_size = [1, 3] + CFG.image_dim, num_samples=num_samples)

<ipython-input-13-314093bf255b> in measure_inference_latency(model, input_size, num_samples, num_warmups)
      6     with torch.no_grad():
      7         for _ in range(num_warmups):
----> 8             _ = model(x)
      9     # torch.cuda.synchronize()
     10 

project/venv/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-1-9c1a0796963a> in forward(self, x)
     71     def forward(self, x):
     72         x = self.quant(x)
---> 73         x = self.model_fp32(x)
     74         x = self.dequant(x)
     75         return x

project/venv/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

project/venv/lib/python3.8/site-packages/timm/models/efficientnet.py in forward(self, x)
    389 
    390     def forward(self, x):
--> 391         x = self.forward_features(x)
    392         x = self.global_pool(x)
    393         if self.drop_rate > 0.:

project/venv/lib/python3.8/site-packages/timm/models/efficientnet.py in forward_features(self, x)
    379 
    380     def forward_features(self, x):
--> 381         x = self.conv_stem(x)
    382         x = self.bn1(x)
    383         x = self.act1(x)

project/venv/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

project/venv/lib/python3.8/site-packages/timm/models/layers/conv2d_same.py in forward(self, x)
     28 
     29     def forward(self, x):
---> 30         return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
     31 
     32 

project/venv/lib/python3.8/site-packages/timm/models/layers/conv2d_same.py in conv2d_same(x, weight, bias, stride, padding, dilation, groups)
     15         padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
     16     x = pad_same(x, weight.shape[-2:], stride, dilation)
---> 17     return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
     18 
     19 

RuntimeError: Could not run 'aten::thnn_conv2d_forward' with arguments from the 'QuantizedCPU' backend. 'aten::thnn_conv2d_forward' is only available for these backends: [CPU, CUDA, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

CPU: registered at /pytorch/build/aten/src/ATen/CPUType.cpp:2127 [kernel]
CUDA: registered at /pytorch/build/aten/src/ATen/CUDAType.cpp:2983 [kernel]
BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
AutogradOther: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:7974 [autograd kernel]
AutogradCPU: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:7974 [autograd kernel]
AutogradCUDA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:7974 [autograd kernel]
AutogradXLA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:7974 [autograd kernel]
AutogradPrivateUse1: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:7974 [autograd kernel]
AutogradPrivateUse2: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:7974 [autograd kernel]
AutogradPrivateUse3: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:7974 [autograd kernel]
Tracer: registered at /pytorch/torch/csrc/autograd/generated/TraceType_0.cpp:9341 [kernel]
Autocast: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:254 [backend fallback]
Batched: registered at /pytorch/aten/src/ATen/BatchingRegistrations.cpp:511 [backend fallback]
VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

Note that if I use `model_fp32 = nn.Conv2d(3, 1, 3)` instead of `model_fp32 = Model()`, the code runs as expected.