Selective Quantization

Hi, I'd like to quantize layers selectively, because some layers in my project only serve as regularizers. I tried a couple of approaches and am confused by the following results.

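import torch
import torch.nn as nn

class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Single fully connected layer on flattened 28x28 inputs
        self.l1 = nn.Linear(28 * 28, 10)
        self.relu1 = nn.ReLU(inplace=True)

    def forward(self, x):
        # Flatten each sample to a vector before the linear layer
        return self.relu1(self.l1(x.view(x.size(0), -1)))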

1. Selective qconfig assignment and top-level transform

model = LeNet()
model.l1.qconfig = torch.quantization.get_default_qat_qconfig()
torch.quantization.prepare_qat(model, inplace=True)
print(model)


2. Selective qconfig assignment and selective transform

model2 = LeNet()
model2.l1.qconfig = torch.quantization.get_default_qat_qconfig()
torch.quantization.prepare_qat(model2.l1, inplace=True)
print(model2)

You can see that the 2nd case doesn’t have (weight_fake_quant): FakeQuantize. Is this the correct behavior? Shouldn’t both yield the same transformed model?
Also, if there is a better way to do selective quantization (e.g. different bit widths, or quantizing some layers but not others), please advise.

You can use the model.layer.qconfig = None syntax to turn off quantization for a layer and all of its children. See https://pytorch.org/docs/stable/quantization.html#model-preparation-for-quantization for more context.
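For illustration, here is a minimal sketch of that syntax. The TwoLayer model and its layer sizes are made up for this example (they are not from the original post), and it assumes the eager-mode torch.quantization API used elsewhere in this thread. The qconfig is set on the whole model, then switched off for l2 only:

```python
import torch
import torch.nn as nn

class TwoLayer(nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 32)
        self.l2 = nn.Linear(32, 10)

    def forward(self, x):
        return self.l2(torch.relu(self.l1(x.view(x.size(0), -1))))

model = TwoLayer()
# Default QAT qconfig for the whole model...
model.qconfig = torch.quantization.get_default_qat_qconfig()
# ...except l2 (and any children it might have), which stays in float.
model.l2.qconfig = None

# Prepare at the top level so that l1 is a child of the root module.
torch.quantization.prepare_qat(model, inplace=True)

print(type(model.l1))  # expected to be the QAT Linear class
print(type(model.l2))  # expected to remain torch.nn.Linear
```

If this behaves as expected, only l1 should show a weight_fake_quant entry when you print(model).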

Right, so in both of my examples, qconfig is set only for the ‘l1’ layer. But the question is: should I always call prepare_qat at the top level? Why does calling prepare_qat on one particular layer at a time give a different result? Any explanation or insight would be appreciated.

I think the reason is that prepare_qat() calls convert(), which doesn’t convert the root module. In case 2, model2.l1 is the root module passed to prepare_qat(), so it is not converted: if you print its type, it is still <class 'torch.nn.modules.linear.Linear'>, which has no weight_fake_quant attribute. In case 1, model.l1 is a child of the root, so it becomes <class 'torch.nn.qat.modules.linear.Linear'>, which does have a weight_fake_quant attribute.
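To make the mechanism concrete, here is a toy sketch in plain Python. It is not PyTorch's actual code: Module, FloatLinear, QatLinear and swap_children are all hypothetical stand-ins. It only assumes that the swap works by reassigning each child on its parent, which would explain why the object you pass in is never replaced itself:

```python
class Module:
    """Toy stand-in for a module that holds named children."""
    def __init__(self, **children):
        self.children = dict(children)

class FloatLinear(Module):
    """Toy stand-in for a float layer."""

class QatLinear(Module):
    """Toy stand-in for the fake-quantized replacement."""

def swap_children(module):
    # Replace each FloatLinear child by reassigning it on its parent.
    for name, child in list(module.children.items()):
        swap_children(child)
        if type(child) is FloatLinear:
            module.children[name] = QatLinear()
    # The object passed in is returned unchanged; nothing reassigns it.
    return module

# Case 1: start at the parent, so l1 is a child and gets swapped.
top = Module(l1=FloatLinear())
swap_children(top)
print(type(top.children["l1"]).__name__)   # QatLinear

# Case 2: start at l1 itself; it has no parent in this call, so it stays.
top2 = Module(l1=FloatLinear())
swap_children(top2.children["l1"])
print(type(top2.children["l1"]).__name__)  # FloatLinear
```

In this picture, the layer you want converted has to be a child of whatever you pass in, which matches calling prepare_qat on the top-level model as in case 1.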

Sounds like you’re right: it doesn’t convert the module that is passed directly to prepare_qat().

print(type(model2.l1))
<class 'torch.nn.modules.linear.Linear'>