While running inference on a quantized model, a RuntimeError occurred at a torch.nn.quantized.FloatFunctional summation operation.

Error at

/usr/local/lib/python3.6/dist-packages/torch/nn/quantized/modules/functional_modules.py in add(self, x, y)
     43     def add(self, x, y):
     44         # type: (Tensor, Tensor) -> Tensor
---> 45         r = torch.add(x, y)
     46         r = self.activation_post_process(r)
     47         return r

Full error message

RuntimeError: Could not run 'aten::add.Tensor' with arguments from the 'QuantizedCPU' backend. 'aten::add.Tensor' is only available for these backends: [CPU, MkldnnCPU, SparseCPU, Meta, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

CPU: registered at /pytorch/build/aten/src/ATen/CPUType.cpp:2136 [kernel]
MkldnnCPU: registered at /pytorch/build/aten/src/ATen/MkldnnCPUType.cpp:144 [kernel]
SparseCPU: registered at /pytorch/build/aten/src/ATen/SparseCPUType.cpp:239 [kernel]
Meta: registered at /pytorch/aten/src/ATen/native/BinaryOps.cpp:1049 [kernel]
BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: fallthrough registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
AutogradOther: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:8041 [autograd kernel]
AutogradCPU: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:8041 [autograd kernel]
AutogradCUDA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:8041 [autograd kernel]
AutogradXLA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:8041 [autograd kernel]
AutogradPrivateUse1: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:8041 [autograd kernel]
AutogradPrivateUse2: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:8041 [autograd kernel]
AutogradPrivateUse3: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:8041 [autograd kernel]
Tracer: registered at /pytorch/torch/csrc/autograd/generated/TraceType_2.cpp:9726 [kernel]
Autocast: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:254 [backend fallback]
Batched: registered at /pytorch/aten/src/ATen/BatchingRegistrations.cpp:531 [kernel]
VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

Code snippet


    def forward(self, x, x2=None, x3=None):
        x_size = x.size()
        resl = x
        for i in range(len(self.pools_sizes)):

            y = self.convs[i](self.pools[i](x))
            q_add0 = FloatFunctional()
            # error is because of this line below
            resl = q_add0.add(resl, nn.functional.interpolate(y, x_size[2:], mode='bilinear', align_corners=True))
        resl = self.relu(resl)
        if self.need_x2:
            
            resl = nn.functional.interpolate(resl, x2.size()[2:], mode='bilinear', align_corners=True)
        resl = self.conv_sum(resl)
        if self.need_fuse:
            q_add1 = FloatFunctional()
            q_add2 = FloatFunctional()
            resl = self.conv_sum_c(q_add1.add(q_add2.add(resl, x2), x3))
        return resl

I tried to do as mentioned in the post.

If I eval the quantized model, the part of the model around the add operation looks like:

(conv_sum): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=2.668934655503108e-07, zero_point=66, padding=(1, 1), bias=False)

(conv_sum_c): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=1.018745024339296e-06, zero_point=56, padding=(1, 1), bias=False)

If I instead use torch.nn.quantized.QFunctional, the model does not get quantized, and the error says something like the op cannot be run with arguments from the CPU backend, only the QuantizedCPU backend.

Any idea why?

Inference code

import torch
import torchvision
from torchvision import transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

valdir = '/content/test'

dataset_test = torchvision.datasets.ImageFolder(
    valdir,
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]))

test_sampler = torch.utils.data.SequentialSampler(dataset_test)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1,
    sampler=test_sampler)

model.load_state_dict(torch.load('/content/model.pth'))
model.eval()

with torch.no_grad():
    for image, target in data_loader_test:
     
        print(image.size()) #torch.Size([1, 3, 224, 224])
        output = model(image)
        print(output)

I think the inference code is fine; the problem is in the add operation. Please give me some ideas on how to solve this.


How are you quantizing? Did you set up qconfigs for the FloatFunctionals (q_add0, q_add1 and q_add2) as well? If you set up qconfigs for these, they will get converted to quantized::add, and that op does work on quantized tensors.
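Roughly, the eager-mode flow would look like this (just a sketch, assuming the FloatFunctionals are registered as submodules, e.g. self.q_add0 in __init__, so that the top-level qconfig propagates to them):

import torch
from torch.quantization import get_default_qconfig, prepare, convert

model.eval()
model.qconfig = get_default_qconfig('fbgemm')  # propagated to all submodules, including the FloatFunctionals
prepare(model, inplace=True)                   # attaches observers (also to q_add0/q_add1/q_add2)
# ... run a few calibration batches through model(...) here ...
convert(model, inplace=True)                   # FloatFunctional -> QFunctional, whose .add dispatches to quantized::add
print(model.q_add0)                            # expect something like QFunctional(scale=..., zero_point=...)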


Yeah, it looks like the model was not converted to the quantized version correctly. As @dskhudia mentioned, make sure you have qconfigs set on all the layers that need to be quantized.


Actually, I see what your problem is. You are building the model incorrectly: the FloatFunctional is a “stateful” layer that needs to be initialized in the model constructor. Otherwise, it will not be visible to the convert script. Here is how you can rewrite the model (just an example):

    def __init__(self):
        super().__init__()
        # ... any other definitions
        self.q_add0 = FloatFunctional()
        self.q_add1 = FloatFunctional()
        self.q_add2 = FloatFunctional()
        # ... any other definitions

    def forward(self, x, x2=None, x3=None):
        x_size = x.size()
        resl = x
        for i in range(len(self.pools_sizes)):
            y = self.convs[i](self.pools[i](x))
            # q_add0 = FloatFunctional()  # moved to __init__
            resl = self.q_add0.add(resl, nn.functional.interpolate(y, x_size[2:], mode='bilinear', align_corners=True))
        resl = self.relu(resl)
        if self.need_x2:
            resl = nn.functional.interpolate(resl, x2.size()[2:], mode='bilinear', align_corners=True)
        resl = self.conv_sum(resl)
        if self.need_fuse:
            # q_add1 = FloatFunctional()  # moved to __init__
            # q_add2 = FloatFunctional()  # moved to __init__
            resl = self.conv_sum_c(self.q_add1.add(self.q_add2.add(resl, x2), x3))
        return resl

OK, got it. Thank you for replying; it has solved my problem.

One last thing: do I have to create a unique FloatFunctional for each summation operation in the for loop?
For example,

        resl = self.q_add00.add(resl, z0)
        resl = self.q_add01.add(resl, z1)   
        resl = self.q_add02.add(resl, z2)

If I do it as mentioned above, that part of the model will look something like:

      (q_add00): QFunctional(
        scale=1.027651309967041, zero_point=67
        (activation_post_process): Identity()
      )
      (q_add01): QFunctional(
        scale=1.0117942094802856, zero_point=68
        (activation_post_process): Identity()
      )
      (q_add02): QFunctional(
        scale=0.9806106686592102, zero_point=74
        (activation_post_process): Identity()
      )

I do not know what is wrong, but the quantized model has 0% accuracy. That is why I am trying different approaches.

Quantization-aware training could be one solution, but what other parameters or procedures should I check or apply in post-training static quantization to get better accuracy for models with a ResNet-50 backbone?

Thank you in advance.

Thank you for replying.
I am currently simply using quantize_model(model, 'fbgemm'). I have also tried taking apart quantize_model and executing it line by line. In both cases, QFunctional never appeared in the model. I did not realize it was supposed to be a layer.

I am sorry but what did you mean by,

I am not sure what this script is. Is it from the tutorials? Either way, if you follow the steps shown in the static quantization tutorial, it should convert your model to the quantized version. As for the accuracy, it could be that low if you don't calibrate your model before quantizing it: (beta) Static Quantization with Eager Mode in PyTorch — PyTorch Tutorials 2.1.1+cu121 documentation
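For reference, calibration just means running representative data through the prepared (observer-instrumented) model before convert(). A rough sketch, assuming the model has already gone through torch.quantization.prepare, and where data_loader_calib is a hypothetical loader over a few hundred representative images:

model.eval()
with torch.no_grad():
    for i, (image, _) in enumerate(data_loader_calib):
        model(image)    # observers record activation ranges used later for scale/zero_point
        if i >= 300:    # a few hundred images is usually enough for calibration
            break
# only after this calibration pass call torch.quantization.convert(model, inplace=True)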


Thank you for your reply. It has mostly solved my problem.
I want to quantize the available salient object detection models.

is one of them.