Dear all,
I’m trying to quantize only some layers of a ResNet50. To do so, I implemented my own QuantizableResNet class and moved the positions where QuantStub and DeQuantStub are applied according to the quantization configuration. After conversion, the forward pass fails with the following error:
NotImplementedError: Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::conv2d.new' is only available for these backends: [QuantizedCPU, QuantizedCUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].
If I move quant and dequant back to their usual positions (wrapping the whole forward pass), everything works. Can you help me with this problem? Do you know how I can do partial quantization?
Below you can find the code to reproduce the error.
Thanks in advance!
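First, the model definition (resnet50.py):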
from typing import Any, List, Type, Union

import torch
import torch.nn as nn
from torch import Tensor
from torchvision.models.resnet import (
    BasicBlock,
    Bottleneck,
    ResNet,
)

class QuantizableBasicBlock(BasicBlock):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # FloatFunctional so the residual add + ReLU can be observed and quantized
        self.add_relu = torch.nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.add_relu.add_relu(out, identity)
        return out

class QuantizableBottleneck(Bottleneck):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # FloatFunctional handles the residual add + ReLU in quantized mode
        self.skip_add_relu = nn.quantized.FloatFunctional()
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.skip_add_relu.add_relu(out, identity)
        return out

class QuantizableResNet(ResNet):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        # The stem and layer1 stay in float
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.quant(x)    # quantize only from layer2 onwards
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.dequant(x)  # back to float before avgpool and fc
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

def _resnet(
    block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
    layers: List[int],
    **kwargs: Any,
) -> QuantizableResNet:
    model = QuantizableResNet(block, layers, **kwargs)
    return model


def ResNet50() -> QuantizableResNet:
    return _resnet(QuantizableBottleneck, [3, 4, 6, 3])
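And the script that prepares, converts, and runs the model (the error is raised on the final forward pass):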
import torch

from resnet50 import ResNet50

modelq = ResNet50()
modelq.fc = torch.nn.Linear(modelq.fc.in_features, 2)

my_qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

# Assign the qconfig only to the submodules of layer2, layer3 and layer4
for name, module in modelq.named_modules():
    if any(layer in name for layer in ('layer2', 'layer3', 'layer4')) and any(
        part in name for part in ('conv', 'bn', 'relu', 'quant', 'downsample')
    ):
        module.qconfig = my_qconfig

torch.quantization.prepare_qat(modelq, inplace=True)  # insert observers/fake-quant
modelq.eval()
modelq.to("cpu")
torch.quantization.convert(modelq, inplace=True)      # swap in quantized modules

x = torch.rand((1, 3, 700, 700))
modelq(x)  # raises the NotImplementedError above
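For reference, this is roughly what I mean by moving quant and dequant back: the standard placement from torchvision's quantizable ResNet, wrapping the whole forward pass (together with a single model-wide qconfig instead of the per-layer loop, so take this as a sketch of my working baseline):

# Working baseline: quantize the entire forward pass.
def forward(self, x: Tensor) -> Tensor:
    x = self.quant(x)  # quantize right at the input
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.fc(x)
    x = self.dequant(x)  # back to float right before returning
    return x

# ...with one qconfig for the whole model instead of the per-layer loop:
# modelq.qconfig = my_qconfig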