Model quantization fails, but the network architecture looks OK

I have the following model definition, which I know for a fact quantized without any problems circa November 2020:

import torch
from torch import nn

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_1 = torch.quantization.QuantStub()
        self.conv_1_1 = nn.Conv2d(3, 64, 3)
        torch.nn.init.kaiming_normal_(self.conv_1_1.weight)
        self.relu_1_2 = nn.ReLU()
        self.norm_1_3 = nn.BatchNorm2d(64)
        self.conv_1_4 = nn.Conv2d(64, 64, 3)
        torch.nn.init.kaiming_normal_(self.conv_1_4.weight)
        self.relu_1_5 = nn.ReLU()
        self.norm_1_6 = nn.BatchNorm2d(64)
        self.pool_1_7 = nn.MaxPool2d(2)

        self.conv_2_1 = nn.Conv2d(64, 128, 3)
        torch.nn.init.kaiming_normal_(self.conv_2_1.weight)
        self.relu_2_2 = nn.ReLU()
        self.norm_2_3 = nn.BatchNorm2d(128)
        self.conv_2_4 = nn.Conv2d(128, 128, 3)
        torch.nn.init.kaiming_normal_(self.conv_2_4.weight)
        self.relu_2_5 = nn.ReLU()
        self.norm_2_6 = nn.BatchNorm2d(128)
        self.pool_2_7 = nn.MaxPool2d(2)

        self.conv_3_1 = nn.Conv2d(128, 256, 3)
        torch.nn.init.kaiming_normal_(self.conv_3_1.weight)
        self.relu_3_2 = nn.ReLU()
        self.norm_3_3 = nn.BatchNorm2d(256)
        self.conv_3_4 = nn.Conv2d(256, 256, 3)
        torch.nn.init.kaiming_normal_(self.conv_3_4.weight)
        self.relu_3_5 = nn.ReLU()
        self.norm_3_6 = nn.BatchNorm2d(256)
        self.pool_3_7 = nn.MaxPool2d(2)

        self.conv_4_1 = nn.Conv2d(256, 512, 3)
        torch.nn.init.kaiming_normal_(self.conv_4_1.weight)
        self.relu_4_2 = nn.ReLU()
        self.norm_4_3 = nn.BatchNorm2d(512)
        self.conv_4_4 = nn.Conv2d(512, 512, 3)
        torch.nn.init.kaiming_normal_(self.conv_4_4.weight)
        self.relu_4_5 = nn.ReLU()
        self.norm_4_6 = nn.BatchNorm2d(512)
        self.dq_1 = torch.quantization.DeQuantStub()

        # deconv is the '2D transposed convolution operator'
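        # Decoder: the transposed convolutions and the skip concatenations run
        # in float, between a DeQuantStub and the next QuantStub.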
        self.deconv_5_1 = nn.ConvTranspose2d(512, 256, (2, 2), 2)
        # 61x61 -> 48x48 crop
        self.c_crop_5_2 = lambda x: x[:, :, 6:54, 6:54]
        self.concat_5_3 = lambda x, y: torch.cat((x, y), dim=1)
        self.q_2 = torch.quantization.QuantStub()
        self.conv_5_4 = nn.Conv2d(512, 256, 3)
        torch.nn.init.kaiming_normal_(self.conv_5_4.weight)
        self.relu_5_5 = nn.ReLU()
        self.norm_5_6 = nn.BatchNorm2d(256)
        self.conv_5_7 = nn.Conv2d(256, 256, 3)
        torch.nn.init.kaiming_normal_(self.conv_5_7.weight)
        self.relu_5_8 = nn.ReLU()
        self.norm_5_9 = nn.BatchNorm2d(256)
        self.dq_2 = torch.quantization.DeQuantStub()

        self.deconv_6_1 = nn.ConvTranspose2d(256, 128, (2, 2), 2)
        # 121x121 -> 88x88 crop
        self.c_crop_6_2 = lambda x: x[:, :, 17:105, 17:105]
        self.concat_6_3 = lambda x, y: torch.cat((x, y), dim=1)
        self.q_3 = torch.quantization.QuantStub()
        self.conv_6_4 = nn.Conv2d(256, 128, 3)
        torch.nn.init.kaiming_normal_(self.conv_6_4.weight)
        self.relu_6_5 = nn.ReLU()
        self.norm_6_6 = nn.BatchNorm2d(128)
        self.conv_6_7 = nn.Conv2d(128, 128, 3)
        torch.nn.init.kaiming_normal_(self.conv_6_7.weight)
        self.relu_6_8 = nn.ReLU()
        self.norm_6_9 = nn.BatchNorm2d(128)
        self.dq_3 = torch.quantization.DeQuantStub()

        self.deconv_7_1 = nn.ConvTranspose2d(128, 64, (2, 2), 2)
        # 252x252 -> 168x168 crop
        self.c_crop_7_2 = lambda x: x[:, :, 44:212, 44:212]
        self.concat_7_3 = lambda x, y: torch.cat((x, y), dim=1)
        self.q_4 = torch.quantization.QuantStub()
        self.conv_7_4 = nn.Conv2d(128, 64, 3)
        torch.nn.init.kaiming_normal_(self.conv_7_4.weight)
        self.relu_7_5 = nn.ReLU()
        self.norm_7_6 = nn.BatchNorm2d(64)
        self.conv_7_7 = nn.Conv2d(64, 64, 3)
        torch.nn.init.kaiming_normal_(self.conv_7_7.weight)
        self.relu_7_8 = nn.ReLU()
        self.norm_7_9 = nn.BatchNorm2d(64)

        # 1x1 conv ~= fc; n_classes = 9
        self.conv_8_1 = nn.Conv2d(64, 9, 1)
        self.dq_4 = torch.quantization.DeQuantStub()

        # residual (skip) connections need to be dequantized separately
        self.dq_resid_1 = torch.quantization.DeQuantStub()
        self.dq_resid_2 = torch.quantization.DeQuantStub()
        self.dq_resid_3 = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.q_1(x)
        x = self.conv_1_1(x)
        x = self.relu_1_2(x)
        x = self.norm_1_3(x)
        x = self.conv_1_4(x)
        x = self.relu_1_5(x)
        x_resid_1_quantized = self.norm_1_6(x)
        x = self.pool_1_7(x_resid_1_quantized)
        x_resid_1 = self.dq_resid_1(x_resid_1_quantized)

        x = self.conv_2_1(x)
        x = self.relu_2_2(x)
        x = self.norm_2_3(x)
        x = self.conv_2_4(x)
        x = self.relu_2_5(x)
        x_resid_2_quantized = self.norm_2_6(x)
        x = self.pool_2_7(x_resid_2_quantized)
        x_resid_2 = self.dq_resid_2(x_resid_2_quantized)

        x = self.conv_3_1(x)
        x = self.relu_3_2(x)
        x = self.norm_3_3(x)
        x = self.conv_3_4(x)
        x = self.relu_3_5(x)
        x_resid_3_quantized = self.norm_3_6(x)
        x = self.pool_3_7(x_resid_3_quantized)
        x_resid_3 = self.dq_resid_3(x_resid_3_quantized)

        x = self.conv_4_1(x)
        x = self.relu_4_2(x)
        x = self.norm_4_3(x)
        x = self.conv_4_4(x)
        x = self.relu_4_5(x)
        x = self.norm_4_6(x)
        x = self.dq_1(x)

        x = self.deconv_5_1(x)
        x = self.concat_5_3(self.c_crop_5_2(x_resid_3), x)
        x = self.q_2(x)
        x = self.conv_5_4(x)
        x = self.relu_5_5(x)
        x = self.norm_5_6(x)
        x = self.conv_5_7(x)
        x = self.relu_5_8(x)
        x = self.norm_5_9(x)
        x = self.dq_2(x)

        x = self.deconv_6_1(x)
        x = self.concat_6_3(self.c_crop_6_2(x_resid_2), x)
        x = self.q_3(x)
        x = self.conv_6_4(x)
        x = self.relu_6_5(x)
        x = self.norm_6_6(x)
        x = self.conv_6_7(x)
        x = self.relu_6_8(x)
        x = self.norm_6_9(x)
        x = self.dq_3(x)

        x = self.deconv_7_1(x)
        x = self.concat_7_3(self.c_crop_7_2(x_resid_1), x)
        x = self.q_4(x)
        x = self.conv_7_4(x)
        x = self.relu_7_5(x)
        x = self.norm_7_6(x)
        x = self.conv_7_7(x)
        x = self.relu_7_8(x)
        x = self.norm_7_9(x)

        x = self.conv_8_1(x)
        x = self.dq_4(x)
        return x

When I attempt to quantize this model using the latest PyTorch (1.8.1):

import time
import torch

def get_model():
    model = UNet()
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    checkpoints_dir = '/mnt/checkpoints'
    model.load_state_dict(
        torch.load(f"{checkpoints_dir}/model_50.pth", map_location=torch.device('cpu'))
    )
    model.eval()

    # NOTE(aleksey): we could potentially speed this up even more by switching from
    # conv->relu->batchnorm order to conv->batchnorm->relu order. PyTorch currently supports
    # conv->batchnorm->relu fusion *only*.
    #
    # Which placement of the relu layer is optimal is a subject of academic debate, and the
    # order the model *currently* uses seems to be the more popular option. I'm leaving the
    # order as-is out of laziness, but you can probably speed things up a little more by
    # making that more invasive change.
    model = torch.quantization.fuse_modules(
        model,
        [
            ['conv_1_1', 'relu_1_2'],
            ['conv_1_4', 'relu_1_5'],
            ['conv_2_1', 'relu_2_2'],
            ['conv_2_4', 'relu_2_5'],
            ['conv_3_1', 'relu_3_2'],
            ['conv_3_4', 'relu_3_5'],
            ['conv_4_1', 'relu_4_2'],
            ['conv_4_4', 'relu_4_5'],
        ]
    )
    model = torch.quantization.prepare(model)
    print(f"Quantizing the model...")
    start_time = time.time()

    dataloader = get_dataloader()
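    # Calibration: run a few representative batches through the prepared model
    # so the inserted observers can record activation ranges before convert().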
    for i, (batch, segmap) in enumerate(dataloader):
        model(batch)

    model = torch.quantization.convert(model)
    print(f"Quantization done in {str(time.time() - start_time)} seconds.")
    model.eval()
    return model

I receive the following error:

AssertionError: Per channel weight observer is not supported yet for ConvTranspose{n}d.

This error occurs because of the ConvTranspose2d layers in the model, which are not currently supported by the model quantization API. However, I've traced through the model by hand, and everywhere a ConvTranspose2d appears I've carefully bracketed it with QuantStub and DeQuantStub, so it should only ever run on dequantized (float) tensors.

Can anyone spot where the error in this network definition is? Or is this perhaps a bad assert, e.g. a recently introduced bug in PyTorch?

I think the error is due to the fact that the logic which raises it ignores the placement of quant-dequant nodes in the model. Since the global qconfig is used, it expects the ConvTranspose modules to be quantized with that same qconfig and raises the error. I have filed ConvTranspose config error in prepare API · Issue #57420 · pytorch/pytorch · GitHub to track this.
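For example, the check trips even when there are no QuantStub/DeQuantStub modules involved at all, because it runs while the global qconfig is being propagated to submodules, before forward() is ever looked at (a minimal sketch, assuming the eager-mode API in 1.8):

import torch
from torch import nn

# The per-channel weight observer in the default 'fbgemm' qconfig reaches the
# ConvTranspose2d during qconfig propagation and trips the assert, regardless
# of where quant/dequant stubs would sit in forward().
m = nn.Sequential(nn.ConvTranspose2d(8, 8, 2, 2))
m.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(m)  # AssertionError: Per channel weight observer is not supported yet ...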

To avoid this error, can you try setting the qconfig of all the ConvTranspose modules to None? Hopefully that should fix it.
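For the UNet above, that would look roughly like this (a sketch; the fuse/prepare/calibrate/convert steps in get_model() stay the same):

model = UNet()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# A qconfig of None opts these modules out of quantization, so prepare() and
# convert() leave them in float -- which matches forward(), where they already
# run between a DeQuantStub and the next QuantStub.
model.deconv_5_1.qconfig = None
model.deconv_6_1.qconfig = None
model.deconv_7_1.qconfig = None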

yes, +1, this is the recommended solution