Is there a way to quantize a conv_transpose2d layer?

The following is my error message:

    Traceback (most recent call last):
      File "pose_estimation/test_on_single_image_quant_ver.py", line 119, in <module>
        main()
      File "pose_estimation/test_on_single_image_quant_ver.py", line 92, in main
        output = quantized_model(input)
      File "/usr/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
        result = self.forward(*input, **kwargs)
    RuntimeError: The following operation failed in the TorchScript interpreter.
    Traceback of TorchScript, serialized code (most recent call last):
      File "code/__torch__/models/pose_mobilenet.py", line 17, in forward
        x0 = (self.quant).forward(x, )
        x1 = (self.features).forward(x0, )
        x2 = (self.conv_transpose_layers).forward(x1, )
             ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        x3 = (self.final_layer).forward(x2, )
        return (self.dequant).forward(x3, )
      File "code/__torch__/torch/nn/modules/container/___torch_mangle_4.py", line 26, in forward
        _8 = getattr(self, "8")
        input0 = (_0).forward(input, None, )
        input1 = (_1).forward(input0, )
                 ~~~~~~~~~~~ <--- HERE
        input2 = (_2).forward(input1, )
        input3 = (_3).forward(input2, None, )
      File "code/__torch__/torch/nn/modules/container/___torch_mangle_4.py", line 25, in forward
        _7 = getattr(self, "7")
        _8 = getattr(self, "8")
        input0 = (_0).forward(input, None, )
                 ~~~~~~~~~~~ <--- HERE
        input1 = (_1).forward(input0, )
        input2 = (_2).forward(input1, )
      File "code/__torch__/torch/nn/modules/conv.py", line 22, in forward
        output_size: Optional[List[int]]=None) -> Tensor:
        output_padding = (self)._output_padding(input, output_size, [2, 2], [1, 1], [4, 4], )
        _0 = torch.conv_transpose2d(input, self.weight, self.bias, [2, 2], [1, 1], output_padding, 1, [1, 1])
             ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        return _0
      def _output_padding(self: __torch__.torch.nn.modules.conv.ConvTranspose2d,

    Traceback of TorchScript, original code (most recent call last):
      File "/usr/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
        def forward(self, input):
            for module in self:
                input = module(input)
                        ~~~~~~ <--- HERE
            return input
      File "/usr/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
        def forward(self, input):
            for module in self:
                input = module(input)
                        ~~~~~~ <--- HERE
            return input
      File "/usr/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 905, in forward
        output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)

        return F.conv_transpose2d(
               ~~~~~~~~~~~~~~~~~~ <--- HERE
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)

    RuntimeError: Could not run 'aten::slow_conv_transpose2d' with arguments from the 'QuantizedCPU' backend. 'aten::slow_conv_transpose2d' is only available for these backends: [CPU, CUDA, Autograd, Profiler, Tracer].

I'm new to this. I tried to find something like torch.nn.quantized.conv_transpose2d but couldn't find it. Is there any other way?
Thanks in advance

Hi @ruka, we landed support for quantized conv transpose recently (https://github.com/pytorch/pytorch/pull/40371 and the preceding PRs). It is not in v1.6, but you can try it out in the nightly!

Thank you so much! I will try it. :smiley: @Vasiliy_Kuznetsov


Hi @Vasiliy_Kuznetsov, I updated my PyTorch to the nightly (version 1.7.0a0+60665ac).
But when I try to convert my model, I get an error:

    Traceback (most recent call last):
      File "pose_estimation/quantized.py", line 67, in <module>
        main()
      File "pose_estimation/quantized.py", line 61, in main
        torch.quantization.convert(model, inplace = True)
      File "/home/yjwen/local/lib/python3.8/site-packages/torch/quantization/quantize.py", line 414, in convert
        _convert(module, mapping, inplace=True)
      File "/home/yjwen/local/lib/python3.8/site-packages/torch/quantization/quantize.py", line 458, in _convert
        _convert(mod, mapping, inplace=True)
      File "/home/yjwen/local/lib/python3.8/site-packages/torch/quantization/quantize.py", line 459, in _convert
        reassign[name] = swap_module(mod, mapping)
      File "/home/yjwen/local/lib/python3.8/site-packages/torch/quantization/quantize.py", line 485, in swap_module
        new_mod = mapping[type(mod)].from_float(mod)
      File "/home/yjwen/local/lib/python3.8/site-packages/torch/nn/quantized/modules/conv.py", line 507, in from_float
        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
      File "/home/yjwen/local/lib/python3.8/site-packages/torch/nn/quantized/modules/conv.py", line 641, in __init__
        super(ConvTranspose2d, self).__init__(
      File "/home/yjwen/local/lib/python3.8/site-packages/torch/nn/quantized/modules/conv.py", line 476, in __init__
        super(_ConvTransposeNd, self).__init__(
      File "/home/yjwen/local/lib/python3.8/site-packages/torch/nn/quantized/modules/conv.py", line 53, in __init__
        self.set_weight_bias(qweight, bias_float)
      File "/home/yjwen/local/lib/python3.8/site-packages/torch/nn/quantized/modules/conv.py", line 650, in set_weight_bias
        self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(
    RuntimeError: FBGEMM doesn't support transpose packing yet!

Did I miss anything (maybe something special I need to do before quantizing conv transpose), or is this a bug?
Thanks

Currently, ConvTranspose is only supported with QNNPACK. An FBGEMM version is planned, but there is no specific date for it yet. Meanwhile, you have two options in eager mode:
1) Replace the instances of ConvTranspose with a dequant->ConvTranspose->quant construct.
2) Set torch.backends.quantized.engine = 'qnnpack' before running your model. You also might need to set qconfig = torch.quantization.get_default_qconfig('qnnpack') or qconfig = torch.quantization.get_default_qat_qconfig('qnnpack').
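A minimal end-to-end sketch of option 2, assuming a PyTorch build that includes quantized ConvTranspose2d (the 1.7 nightly mentioned above, or later) and the QNNPACK engine. TinyDecoder is a hypothetical stand-in for a model with a transposed-convolution head, and the random calibration inputs are placeholders for real data.

```python
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Hypothetical model: quantize -> ConvTranspose2d -> dequantize."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.deconv = nn.ConvTranspose2d(8, 4, kernel_size=4, stride=2, padding=1)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.deconv(self.quant(x)))

# At the time of this thread only QNNPACK could prepack transposed-conv weights,
# so fail early with a readable message if this build lacks it.
if 'qnnpack' not in torch.backends.quantized.supported_engines:
    raise RuntimeError("this PyTorch build does not include the QNNPACK engine")
torch.backends.quantized.engine = 'qnnpack'

model = TinyDecoder().eval()
# Post-training qconfig matching the engine; use get_default_qat_qconfig with prepare_qat for QAT.
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
torch.quantization.prepare(model, inplace=True)

with torch.no_grad():
    for _ in range(4):  # placeholder calibration with random inputs
        model(torch.randn(1, 8, 16, 16))

torch.quantization.convert(model, inplace=True)
out = model(torch.randn(1, 8, 16, 16))
print(type(model.deconv).__name__, tuple(out.shape))
```

With the engine set this way, convert should swap deconv for its quantized counterpart instead of failing in weight prepacking, which is the step that raised the FBGEMM error above.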

@Zafar Thank you so much for the reply! I successfully converted the model by setting torch.backends.quantized.engine = 'qnnpack' and using get_default_qconfig('qnnpack').
However, the quantized model predicts completely wrong results (my original model works fine).
I carefully studied the official tutorial: https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html#model-architecture
It seems that nothing special needs to be done when using a quantized model; it is run the usual way:

    model.eval()
    with torch.no_grad():
        output = model(image)

Any hints?

Could you elaborate on the wrong result, please? I wonder whether it is within the quantization error.
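To make "within the quantization error" concrete: a quint8 tensor stores each value x as an integer code q = clamp(round(x / scale) + zero_point, 0, 255). The pure-Python sketch below (helper names are illustrative, not PyTorch API) shows that values inside the representable range come back within scale/2 of the original, while values outside it are clipped, so a badly calibrated range can produce errors far larger than the rounding error.

```python
def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Affine quantization of one float to an 8-bit integer code."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))  # values outside the range are clipped

def dequantize(q, scale, zero_point):
    """Map an integer code back onto the float grid."""
    return (q - zero_point) * scale

scale, zero_point = 0.1, 128  # representable range: [-12.8, 12.7]
for x in (1.234, -3.96, 20.0):
    x_hat = dequantize(quantize(x, scale, zero_point), scale, zero_point)
    print(f"{x:7.3f} -> {x_hat:7.3f}  error={x - x_hat:+.3f}")
```

The last value, 20.0, lies outside the range and is clipped to 12.7, which illustrates why the observed activation ranges matter as much as the bit width.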

Hi, @Zafar
This is the result from my original model


This is the result from the quantized model

It seems that all 21 keypoints are in the wrong positions.

The following is my conversion code:

    state_dict = torch.load('model_best.pth')
    model.load_state_dict(state_dict)
    model = model.to('cpu')

    torch.backends.quantized.engine = 'qnnpack'
    
    model.eval()
    model.fuse_model()
    model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
    torch.quantization.prepare(model, inplace = True)
    evaluate(model)
    torch.quantization.convert(model, inplace = True)
    torch.jit.save(torch.jit.script(model), 'model_quantization_scripted_quantized.pth')

I am assuming the evaluate runs the model on the training data.
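For post-training static quantization, the step between prepare and convert only needs to push representative inputs through the model so that its observers can record activation ranges. A minimal sketch of such a step; calibrate is a hypothetical helper, and it assumes the loader yields input tensors directly (adapt it if yours yields (image, target) pairs).

```python
import itertools
import torch

def calibrate(prepared_model, batches, num_batches=32):
    """Feed representative inputs through a model returned by
    torch.quantization.prepare so its observers record activation ranges.

    Hypothetical helper: assumes `batches` yields input tensors directly.
    Returns the number of batches actually used.
    """
    prepared_model.eval()
    seen = 0
    with torch.no_grad():  # observers update during forward; no gradients needed
        for images in itertools.islice(batches, num_batches):
            prepared_model(images)
            seen += 1
    return seen
```

In the conversion code above, evaluate(model) plays this role; if it runs only a few unrepresentative images, the recorded ranges, and therefore the scales, can end up poorly matched to real inputs.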

Anyway, the results seem pretty bad. It would be worth investigating the SNR of the model. Is it by any chance an open-source model I could take a look at? If not, that's OK; I can try cooking up my own model and testing it. What's the architecture?

@Zafar Sure, let me wrap up my model definition code and remove some unnecessary dependencies.

@Zafar Hi, I uploaded the code; you can find it here:
https://drive.google.com/file/d/1tyS4lqdq9FWxy-96M5qwJjZno7ViB83p

Thank you for the model. Is there an open-source dataset I could use to reproduce the error? I ran the model with synthetic data, but it is better to reproduce with actual images. Ideally, I would want the pretrained model as well, but if it is not available, that's OK; I can train it myself.

I can run the model with random inputs, and get fairly good results: https://colab.research.google.com/drive/1T_jvh96gekf1OLh_ttbT3Tgi6VNTytTF?usp=sharing


Hi, @Zafar Thank you for the reply. I use this dataset as my training data:
https://www.cs.cmu.edu/~tsimon/projects/mvbs.html
It seems that the original download link is dead, so I uploaded a torrent here:
https://drive.google.com/file/d/1QDJdFuYJGS9Kp4lng_bbv5JF36l4tioc/view?usp=sharing

BTW, regarding the SNR shown in your code: does a small number mean that the outputs of the original model and the quantized model are close to each other?

@ruka For the SNR metric, higher is better; as a rule of thumb, I usually consider 15-20 dB good for quantization. However, this threshold might not hold for all models. I'll take a look at the model and data you sent.
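For reference, a common definition of the SNR between a float output x and a quantized output y is 10 * log10(sum(x^2) / sum((x - y)^2)) dB, so higher is better and 20 dB means the error power is 1% of the signal power. A small pure-Python sketch of that formula (the name snr_db is illustrative, and this is not necessarily the exact code in the linked notebook):

```python
import math

def snr_db(reference, approx):
    """Signal-to-noise ratio in dB; higher means approx is closer to reference."""
    if len(reference) != len(approx):
        raise ValueError("inputs must have the same length")
    signal = sum(x * x for x in reference)
    noise = sum((x - y) ** 2 for x, y in zip(reference, approx))
    if noise == 0:
        return math.inf  # identical outputs
    return 10 * math.log10(signal / noise)

float_out = [1.0, -2.0, 3.0]
quant_out = [1.1, -2.0, 2.9]
print(f"{snr_db(float_out, quant_out):.1f} dB")  # prints 28.5 dB
```

Writing it as 20 * log10(||x|| / ||x - y||) gives the same number, since both take the ratio of squared norms.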

@Zafar Thank you so much :smiley:

@Zafar, I am facing the same issue. It seems to be a keypoints issue. I trained the same model on keypoint data and on classification data (only the last layer changed), and got the following results after quantization:

Trained for keypoints:
Weights SNR: 40-47 dB (almost perfect)
Model output stats SNR: 23 dB (first layer) down to 0.01 dB (last layers). The value kept decreasing with every layer. Why is there no correlation between the weight SNR and the output SNR?
Quantized model accuracy: 0.01

Trained for classification:
Weights SNR: 43-48 dB
Model output stats SNR: 10-22 dB (varying randomly across layers)
Float model accuracy: 92%
Quantized model accuracy: 91.6%

I will be waiting for your analysis.
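One way to reproduce a per-layer "output stats" comparison like this without internal tooling is to record each leaf module's output in the float and quantized models with forward hooks and compute an SNR per layer. This is a sketch with hypothetical helper names; it assumes both models are in eval mode and share module names (for example, the fused float model before prepare and the converted model), and it skips layers whose outputs differ in shape.

```python
import torch
import torch.nn as nn

def capture_outputs(model, inputs):
    """Run `inputs` through `model` and record each leaf module's output by name."""
    outputs, handles = {}, []
    for name, module in model.named_modules():
        if next(module.children(), None) is not None:
            continue  # skip containers; only leaf modules are hooked

        def hook(mod, args, out, name=name):
            if isinstance(out, torch.Tensor):
                if out.is_quantized:
                    out = out.dequantize()  # compare everything in float
                outputs[name] = out.detach().float()

        handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        for handle in handles:
            handle.remove()  # leave the model without hooks afterwards
    return outputs

def per_layer_snr(float_model, quant_model, inputs):
    """SNR in dB of each layer's output, using the float model as reference."""
    ref = capture_outputs(float_model, inputs)
    approx = capture_outputs(quant_model, inputs)
    result = {}
    for name, x in ref.items():
        y = approx.get(name)
        if y is None or y.shape != x.shape:
            continue  # layer missing or reshaped in the other model
        noise = (x - y).pow(2).sum()
        signal = x.pow(2).sum()
        result[name] = float('inf') if noise == 0 else 10 * torch.log10(signal / noise).item()
    return result
```

Looking for the layer where the SNR first drops sharply may help narrow down where precision is lost, as in the steadily decreasing keypoint numbers above.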

@Mamta, @ruka Thank you for reporting this. I will take a look at these models. I currently have a similar issue with generator models; I'm not sure if it is related, but I will be looking at this issue more closely.

Hi @Zafar, I am running into this error when trying to use a ConvTranspose2d with the FBGEMM backend. I am trying to implement workaround (1) that you suggested, i.e. wrapping the ConvTranspose2d with dequant and quant steps, but am struggling to get it right: even with those steps in place, the error ('FBGEMM doesn't support transpose packing yet!') still appears. Do you know of an example of how to implement such wrapping so that quantization goes through? Thank you!

What I did was create QuantStub and DeQuantStub instances during initialization, and then, in forward(), did:

        x = self.dequant(x)
        x = self.transpose_conv2d(x)
        x = self.quant(x)

… but this still gives the same error.

Hi @rfejgin, may I ask which version of PyTorch you are using? Did they announce that ConvTranspose2d is supported on the FBGEMM backend in their latest release? As far as I know, as of about a month ago, ConvTranspose2d was still not supported on the FBGEMM backend.

I’m on 1.7.0. Yes, I’m aware that ConvTranspose2d is not supported on the FBGEMM backend, so I was trying the workaround suggested by @Zafar, which is to wrap the calls to the module with dequant->conv2d_transpose->quant, but couldn’t get that to work.
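A likely explanation (an inference, not something confirmed in this thread): in eager mode, prepare and convert act on every module that carries a qconfig, and setting model.qconfig propagates it to the ConvTranspose2d as well. Wrapping the layer with stubs therefore does not stop convert from swapping it for the quantized ConvTranspose2d, whose weight prepacking raises the FBGEMM error (the from_float frames in the earlier traceback). The sketch below, built around a hypothetical Net, sets qconfig = None on that one layer after assigning the model-level qconfig so that it stays in float. It also uses a second QuantStub for the re-quantization point, since reusing self.quant would make a single observer (and a single scale and zero point) cover two different activations.

```python
import torch
import torch.nn as nn
from torch.quantization import DeQuantStub, QuantStub

class Net(nn.Module):
    """Hypothetical model: quantized conv -> float ConvTranspose2d -> quantized conv."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.dequant = DeQuantStub()
        self.transpose_conv2d = nn.ConvTranspose2d(8, 4, 4, stride=2, padding=1)
        self.quant2 = QuantStub()  # separate stub: its own observer and scale
        self.head = nn.Conv2d(4, 2, 1)
        self.dequant2 = DeQuantStub()

    def forward(self, x):
        x = self.conv(self.quant(x))
        x = self.dequant(x)           # back to float
        x = self.transpose_conv2d(x)  # runs as an ordinary float op
        x = self.quant2(x)            # re-quantize for the rest of the network
        return self.dequant2(self.head(x))

# Prefer FBGEMM as in the question; fall back to QNNPACK on builds without it.
engine = 'fbgemm' if 'fbgemm' in torch.backends.quantized.supported_engines else 'qnnpack'
torch.backends.quantized.engine = engine

model = Net().eval()
model.qconfig = torch.quantization.get_default_qconfig(engine)
model.transpose_conv2d.qconfig = None  # no observer and no module swap for this layer

torch.quantization.prepare(model, inplace=True)
with torch.no_grad():
    for _ in range(4):  # placeholder calibration data
        model(torch.randn(1, 3, 16, 16))
torch.quantization.convert(model, inplace=True)

out = model(torch.randn(1, 3, 16, 16))
print(type(model.transpose_conv2d).__module__, tuple(out.shape))
```

After convert, conv and head should be quantized modules while transpose_conv2d remains a regular nn.ConvTranspose2d, so no transposed weights are ever prepacked.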