Static quantization example not quite working as expected

I have a quantization script (link to full code here). Quoting from the code, here are the relevant model layers:

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant_1 = torch.quantization.QuantStub()
        self.conv_1_1 = nn.Conv2d(3, 64, 3)
        torch.nn.init.kaiming_normal_(self.conv_1_1.weight)
        self.relu_1_2 = nn.ReLU()
        self.norm_1_3 = nn.BatchNorm2d(64)
        self.dequant_1 = torch.quantization.DeQuantStub()
        self.conv_1_4 = nn.Conv2d(64, 64, 3)
        torch.nn.init.kaiming_normal_(self.conv_1_4.weight)
        # continued...

    def forward(self, x):
        x = self.quant_1(x)        
        x = self.conv_1_1(x)
        x = self.relu_1_2(x)
        x = self.norm_1_3(x)
        x = self.dequant_1(x)
        x = self.conv_1_4(x)
        x = self.relu_1_5(x)
        # continued...

I am attempting to quantize the model, then evaluate its performance on a dataset of interest:

    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    checkpoints_dir = '/spell/checkpoints'
    model.load_state_dict(
        torch.load(f"{checkpoints_dir}/model_50.pth", map_location=torch.device('cpu'))
    )
    model.eval()
    
    # NEW
    model = torch.quantization.prepare(model)
    print(f"Quantizing the model...")
    start_time = time.time()
    
    for i, (batch, segmap) in enumerate(dataloader):
        # batch = batch.cuda()
        # segmap = segmap.cuda()
        model(batch)

    model = torch.quantization.convert(model)
    print(f"Quantization done in {str(time.time() - start_time)} seconds.")

    print(f"Evaluating the model...")
    start_time = time.time()
    for i, (batch, segmap) in enumerate(dataloader):
        # batch = batch.cuda()
        # segmap = segmap.cuda()
        model(batch)
    
    print(f"Evaluation done in {str(time.time() - start_time)} seconds.")

The code, as written, fails with the following log output:

Loading the model...
Quantizing the model...
Quantization done in 84.06090354919434 seconds.
Evaluating the model...
Traceback (most recent call last):
  File "/spell/servers/eval_quantized.py", line 277, in <module>
    main()
  File "/spell/servers/eval_quantized.py", line 269, in main
    model(batch)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/spell/servers/eval_quantized.py", line 178, in forward
    x = self.conv_1_4(x)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/quantized/modules/conv.py", line 215, in forward
    self.dilation, self.groups, self.scale, self.zero_point)
RuntimeError: Could not run 'quantized::conv2d' with arguments from the 'CPUTensorId' backend. 'quantized::conv2d' is only available for these backends: [QuantizedCPUTensorId].

This implies that x = self.conv_1_4(x) fails because conv_1_4 has been quantized. But I don't understand why conv_1_4 is quantized in the first place.
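For reference, a quick way to check which layers got swapped (a minimal sketch, not part of the original script, assuming the model variable from the code above) is to print the module types after convert:

    # Hypothetical check: list which submodules convert() swapped for their
    # quantized counterparts after calibration.
    for name, module in model.named_modules():
        print(name, type(module))

Given the traceback goes through torch/nn/quantized/modules/conv.py, conv_1_4 presumably shows up here as the quantized Conv2d class, which would explain why the fp32 tensor coming out of dequant_1 is rejected.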

(Apologies for posting an incomplete question earlier; I hit the wrong button.)

If the intent is to dequantize and run conv_1_4 in fp32, then the problem is that, by default, the quantization APIs quantize every convolution in the model. The workaround is to disable quantization for conv_1_4, like this:

model.qconfig = ...
# disable quant for a specific layer
model.conv_1_4.qconfig = None
# continue with the quantization APIs, conv_1_4 will not be quantized
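To spell out where that line goes in the workflow (a minimal sketch, assuming the same model and dataloader as in the original script):

model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model.conv_1_4.qconfig = None  # keep conv_1_4 in fp32; it consumes the dequantized output of dequant_1

model = torch.quantization.prepare(model)   # observers inserted everywhere except conv_1_4

with torch.no_grad():
    for batch, segmap in dataloader:        # calibration pass
        model(batch)

model = torch.quantization.convert(model)   # conv_1_4 is left as a regular nn.Conv2d

After convert, conv_1_1 runs quantized while conv_1_4 keeps receiving the fp32 tensor produced by dequant_1, so the quantized::conv2d backend error should go away.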

"by default, the quantization APIs quantize all convolutions in the model"

Ah! This was not obvious from reading the documentation; the Quantization quickstart does not mention this behavior, nor, as best I can tell, does the Static quantization tutorial.

May I suggest making a docs task for this? Sorry, I know this is my second one in as many days. :smile: