Could not run 'quantized::conv2d.new' with arguments from the 'QuantizedCUDA' backend

x = self.quant(x)
x = self.conv(x)
x = self.bn(x)
x = self.act(x)
x = self.dequant(x)

I trained a QAT model, and when I tried evaluating it, I got this error:

Could not run 'quantized::conv2d.new' with arguments from the 'QuantizedCUDA' backend … 'quantized::conv2d.new' is only available for these backends: [QuantizedCPU, …].

When I added x = x.to('cpu') before x = self.quant(x) to make the input a QuantizedCPU tensor (note that with this change I can no longer train the model, because I get:

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

which is another problem), I then got:

Could not run 'aten::silu.out' with arguments from the 'QuantizedCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit Internal Login for possible resolutions. 'aten::silu.out' is only available for these backends: [CPU, CUDA, …

So I changed the position of dequant:
x = x.to('cpu')
x = self.quant(x)
x = self.conv(x)
x = self.bn(x)
x = self.dequant(x)
x = self.act(x)

Now I get an error pointing to x = self.quant(x):

Could not run 'aten::quantize_per_tensor' with arguments from the 'QuantizedCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit Internal Login for possible resolutions. 'aten::quantize_per_tensor' is only available for these backends: [CPU, CUDA,

And if I remove x = self.quant(x), I get:

Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit Internal Login for possible resolutions. 'quantized::conv2d.new' is only available for these backends: [QuantizedCPU,

Please help, as I've been running into error after error even after searching online for solutions.

Hi, I'm getting the same problem. If possible, could you share a solution for this error?

Hi, I have not found a solution yet.
cc @jerryzh168

This error occurs when you try to run a quantized op whose weight or input is on CUDA. For example, if you take a correctly quantized model, call .to('cuda'), and then run it, you'd get this error.

Based on the second error message, your weight is on CUDA. Note: changing where x.to('cpu') is placed will not fix this problem if the op's weight itself is on CUDA.

My guess is that somewhere in your code you have model.to('cuda') (likely during training) and you are not converting it back to CPU, i.e. model.to('cpu'), before doing quantization.
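
In code, the order that avoids this looks roughly like the following (a minimal sketch, assuming a CUDA machine; TinyQATModel is a made-up stand-in for your model, not code from this thread):

import torch
import torch.nn as nn

class TinyQATModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))

model = TinyQATModel()
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

model.to('cuda')
# ... QAT fine-tuning on GPU ...

model.to('cpu')    # move the weights back to CPU *before* convert
model.eval()
quantized = torch.quantization.convert(model)
out = quantized(torch.rand(1, 3, 224, 224))   # CPU input for the quantized model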

Additionally, it looks like your self.act op is aten::silu, which isn't being converted to a quantized op (it doesn't appear to have a quantized implementation: https://github.com/pytorch/pytorch/blob/master/torch/nn/quantized/modules/activation.py). You can either implement it yourself or change it to something along the lines of

y = torch.sigmoid(x)
x = y * x
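
One way to package that is a small module that keeps the multiply quantization-aware via FloatFunctional (a hedged sketch; QuantizableSiLU is a made-up name, not an existing PyTorch module):

import torch
import torch.nn as nn
from torch.nn.quantized import FloatFunctional

class QuantizableSiLU(nn.Module):
    """x * sigmoid(x), written with ops that have quantized kernels."""
    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()     # sigmoid has a QuantizedCPU kernel
        self.mul = FloatFunctional()    # quantization-aware elementwise multiply

    def forward(self, x):
        return self.mul.mul(x, self.sigmoid(x))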

I would also suggest starting with a simpler model and making sure the flow works for you before iterating on yours. Something like (beta) Static Quantization with Eager Mode in PyTorch — PyTorch Tutorials 2.1.1+cu121 documentation could be a good starting point.

My guess is that somewhere in your code you have model.to('cuda') (likely during training) and you are not converting it back to CPU, i.e. model.to('cpu'), before doing quantization.

Strange, because I did call model.to('cpu') before torch.quantization.convert(model).

You can inspect the model and check whether the weight is stored correctly; it's possible it's not transferring over, though usually modules move their attributes over by default.
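
Something along these lines could do that check (a sketch; the torchvision resnet18 is just a stand-in for your converted model):

import torch
import torchvision

model = torchvision.models.resnet18()   # stand-in; use your (converted) model here

# Print the device and dtype of every parameter and buffer so a stray CUDA tensor stands out.
for name, p in model.named_parameters():
    print(name, p.device, p.dtype)
for name, b in model.named_buffers():
    print(name, b.device, b.dtype)

# Quantized modules keep their (packed) weights in the state_dict rather than as parameters.
for name, t in model.state_dict().items():
    if isinstance(t, torch.Tensor):
        print(name, t.device, t.dtype)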

Hi @HDCharles and @MrOCW!
Did you find a solution to this issue?
I am facing the same error and could not find one yet.

Also, I checked my code: I always call model_static_quantized.to(torch.device('cpu')) or model.to(torch.device('cpu')), so everything is on CPU, not CUDA.
How can I check whether the weights are stored correctly, please? @HDCharles

Thanks in advance for your help!
Any additional input, @jerryzh168?

Is your model input a CUDA tensor?

This is a user error; it means quantization is not being used correctly. In eager mode, users are expected to reason about the quantized model and to place QuantStub and DeQuantStub and set qconfigs correctly. We may be able to provide more help if you post your modified model for quantization and your qconfig settings.
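
For reference, the expected eager-mode pattern looks roughly like this (a generic sketch, not the model from this thread):

import torch
import torch.nn as nn

class ExampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> quantized at the model entry
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()  # quantized -> float at the model exit

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.conv(x))
        return self.dequant(x)

model = ExampleModel().eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
prepared = torch.quantization.prepare(model)
prepared(torch.rand(8, 3, 224, 224))                     # calibration with representative CPU data
quantized = torch.quantization.convert(prepared)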

Hi @jerryzh168!
The error I have is NotImplementedError: Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend

The inference is done on the following input:
dummy_input = torch.rand((1, 3, 512, 512), dtype=torch.float32, device="cpu")

The modified model for quantization is:

# Define the Model
class ResNet(nn.Module):
    def __init__(self, config, output_dim):
        super().__init__()

        # QuantStub converts tensors from floating point to quantized
        # self.quant = torch.quantization.QuantStub()

        block, n_blocks, channels = config
        self.in_channels = channels[0]

        assert len(n_blocks) == len(channels) == 4

        self.conv1 = nn.Conv2d(
            3,
            self.in_channels,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=3,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self.get_resnet_layer(block, n_blocks[0], channels[0])
        self.layer2 = self.get_resnet_layer(block, n_blocks[1], channels[1], stride=2)
        self.layer3 = self.get_resnet_layer(block, n_blocks[2], channels[2], stride=2)
        self.layer4 = self.get_resnet_layer(block, n_blocks[3], channels[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.in_channels, output_dim)

        # DeQuantStub converts tensors from quantized to floating point
        # self.dequant = torch.quantization.DeQuantStub()




    def get_resnet_layer(self, block, n_blocks, channels, stride=1):

        layers = []

        if self.in_channels != block.expansion * channels:
            downsample = True
        else:
            downsample = False

        layers.append(block(self.in_channels, channels, stride, downsample))

        for i in range(1, n_blocks):
            layers.append(block(block.expansion * channels, channels))

        self.in_channels = block.expansion * channels

        return nn.Sequential(*layers)

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        # x = self.quant(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        h = x.view(x.shape[0], -1)

        x = self.fc(h)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        # x = self.dequant(x)

        return x, h


class BasicBlock(nn.Module):

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=False):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()

        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)


        if downsample:
            conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=stride, bias=False
            )
            bn = nn.BatchNorm2d(out_channels)
            downsample = nn.Sequential(conv, bn)
        else:
            downsample = None

        self.downsample = downsample
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)

        i = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)

        if self.downsample is not None:
            i = self.quant(i)
            i = self.downsample(i)
            i = self.dequant(i)

        x += i
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x


ResNetConfig = namedtuple("ResNetConfig", ["block", "n_blocks", "channels"])

resnet18_config = ResNetConfig(
    block=BasicBlock, n_blocks=[2, 2, 2, 2], channels=[64, 128, 256, 512]
)

About the qconfig settings:

backend = "fbgemm"  # x86 machine
torch.backends.quantized.engine = backend
model_q.qconfig = torch.quantization.get_default_qconfig(backend)

Depending on where I place QuantStub and DeQuantStub, I get this error or the one mentioned above.

Thanks a lot for your help!
Please let me know if you need any additional information.


Hi @sarramrg, have you checked out Quantization — PyTorch master documentation? It has some example code on what this error means and how to resolve it.

Also, have you considered using FX graph mode quantization (Quantization — PyTorch master documentation)? As long as your model is symbolically traceable, this workflow will place the quant/dequant stubs for you automatically.
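
A rough sketch of that workflow (hedged: the prepare_fx signature has changed across PyTorch releases, and example_inputs is required in recent versions):

import torch
import torch.nn as nn
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization import quantize_fx

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()   # stand-in for your model
qconfig_mapping = QConfigMapping().set_global(get_default_qconfig("fbgemm"))
example_inputs = (torch.rand(1, 3, 512, 512),)

prepared = quantize_fx.prepare_fx(model, qconfig_mapping, example_inputs)
with torch.no_grad():                          # calibration pass(es)
    prepared(*example_inputs)
quantized = quantize_fx.convert_fx(prepared)   # quant/dequant placed automatically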

from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
import torch
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
qmodel = torch.quantization.prepare(model,inplace=False)
qmodel.eval()
qmodel = torch.quantization.convert(qmodel)

This is the model I used; I later got the above error. Can you please help?

qmodel(batch)

I think you'll need to get resnet50 from vision/torchvision/models/quantization at main · pytorch/vision · GitHub instead.
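
For example (a hedged sketch, assuming a recent torchvision that ships the quantized weight enums):

import torch
from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

# Ready-made quantized ResNet50: fused, quantized, and loaded with quantized weights.
qmodel = resnet50(weights=ResNet50_QuantizedWeights.DEFAULT, quantize=True)
qmodel.eval()

with torch.no_grad():
    out = qmodel(torch.rand(1, 3, 224, 224))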
