Is there a way to quantize conv_transpose2d layer?

Hi @rfejgin, according to the official tutorial https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html#model-architecture, wrapping the op as dequant->conv_transpose2d->quant is a standard pattern for quantization, but it doesn't tell PyTorch which backend it should use.
Could you try setting the qconfig to QNNPACK, as @Zafar mentioned in step 2)?
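
Roughly something like this (a minimal sketch, assuming eager-mode static quantization; m is whatever module you are quantizing):

import torch

backend = "qnnpack"
m.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend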

Okay, I think I figured out my problem. I was setting the qconfig correctly

backend = "fbgemm"                
m.qconfig = torch.quantization.get_default_qconfig(backend)

but didn’t realize that I also needed to choose the backend, like so:

torch.backends.quantized.engine = backend

With this change I am able to get through the quantization conversion (convert()) using FBGEMM and conv_transpose2d. So maybe they did add FBGEMM support after all?
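
As a side note, one way to double-check this is to look at which quantized engines the build supports and which one is currently selected (a small sketch):

import torch

print(torch.backends.quantized.supported_engines)  # e.g. ['none', 'fbgemm', 'qnnpack']
print(torch.backends.quantized.engine)             # should match the qconfig backend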

Thanks @ruka for the pointers and suggestions, that tutorial was helpful.


Hmm, actually it turns out I'm still getting the error. Not sure what is making it come and go (I've also been playing around with the 'inplace' parameter); I'll try to narrow it down.

Are you getting a new error, or still 'FBGEMM doesn't support transpose packing yet'?

Still “FBGEMM doesn’t support transpose packing yet”.

It appears that as long as a module being quantized contains a conv_transpose2d module, quantization will fail when using the FBGEMM backend, with the error:
RuntimeError: FBGEMM doesn't support transpose packing yet!

This happens even if the forward() function never calls the conv_transpose2d! Given that, it's not surprising that wrapping the transposed convolution in dequant->conv_transpose2d->quant doesn't help either.

Here’s a fairly minimal example:

import torch
import torch.nn as nn
import torch.quantization
from torch.quantization import QuantStub, DeQuantStub

class UpSample(nn.Module):
    def __init__(self, C):
        super(UpSample, self).__init__()
        self.transpose_conv2d = nn.ConvTranspose2d(in_channels=C, out_channels=C, kernel_size = (2,2))
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        # Enabling the following 3 lines doesn't help. 
        #x = self.dequant(x)
        #x = self.transpose_conv2d(x)
        #x = self.quant(x)

        # Even if forward() doesn't call the transposed convolution, we get an error.
        return x

C = 10
model = UpSample(C=C)
x = torch.ones(1, C, 3, 1)

# Set up quantization
backend = 'fbgemm'
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend

# Insert observers
model = torch.quantization.prepare(model, inplace=True)
# Calibrate
_ = model(x)
# Quantize
m = torch.quantization.convert(model, inplace=True)

Result:

Exception has occurred: RuntimeError
FBGEMM doesn't support transpose packing yet!

@Zafar, @jerryzh168: do you know of any way to work around the fact that conv_transpose2d cannot be quantized when using the FBGEMM backend? I don't mind keeping that layer unquantized, but there doesn't seem to be a way to get around the above error, even when dequantizing before the transposed conv. Thank you - any advice would be appreciated!

I think the FBGEMM convtranspose support has landed – can you try again?

I'm on 1.7.0; did you mean it has landed on the latest master? I can try that.

Also, for anyone else trying to wrap an unsupported layer type with dequant-layer-quant, the following appears to work, at least in the sense that quantization completes without error:

class UpSample(nn.Module):
    def __init__(self, C):
        super(UpSample, self).__init__()
        self.transpose_conv2d = nn.ConvTranspose2d(in_channels=C, out_channels=C, kernel_size = (2,2))

        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        # disable quantization on unsupported layer types
        self.transpose_conv2d.qconfig = None

    def forward(self, x):
        x = self.dequant(x)
        x = self.transpose_conv2d(x)
        x = self.quant(x)
        return x
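
And a quick sanity check of the above (a sketch reusing the imports and the prepare/calibrate/convert steps from the earlier example; in a real model this module would sit inside a larger quantized network, so its input would already be quantized):

C = 10
model = UpSample(C=C)

backend = 'fbgemm'
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend

model = torch.quantization.prepare(model, inplace=True)
_ = model(torch.ones(1, C, 3, 1))   # calibrate
m = torch.quantization.convert(model, inplace=True)

# convert() now completes; transpose_conv2d stays a float nn.ConvTranspose2d,
# while the stubs become Quantize/DeQuantize modules.
print(m)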

I tried the latest nightly build, and indeed it looks like conv_transpose2d is now supported even on FBGEMM. But I am getting this error:
Per Channel Quantization is currently disabled for transposed conv

My network is using the default qconfig for FBGEMM, and the inputs to the network are quantized with a QuantStub() at the start of the network. Is there something special I should be doing to choose a different quantization type? Thank you.

This is odd – I thought I landed the change. I will take a look later this week


@Zafar thank you for all your work!

I've just tried to quantize a model with ConvTranspose1d using torch 1.8 and the fbgemm backend, but got this error message: “AssertionError: Per channel weight observer is not supported yet for ConvTranspose{nx}d.”

Could you say when this observer will be released? Is it in progress?

At the moment there is no active work to implement the per-channel observer for ConvTranspose. The reason is that it is a non-trivial task: the observer has to watch the correct channel axis, and that axis is different for Conv and ConvTranspose because their weight layouts differ. If you add a feature request on GitHub, I will try to get to it as soon as I can. Meanwhile, you should use a per-tensor configuration.
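
For example, one way to override just the transposed convolution with a per-tensor qconfig while keeping the FBGEMM default (per-channel weights) for the rest of the model; a sketch, where model and transpose_conv are placeholder names for your own modules:

import torch

# FBGEMM default (per-channel weight observer) for the model as a whole...
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# ...but the default per-tensor qconfig for the transposed conv, so its
# weight observer is the per-tensor MinMaxObserver rather than the
# per-channel variant that triggers the assertion.
model.transpose_conv.qconfig = torch.quantization.default_qconfig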

Thank you for the quick response! I've added an issue.
