Broken disable_fake_quant

It seems that disable_fake_quant and the other similar functions are not working in the following case (PyTorch 1.7).

import torch
import torch.nn as nn

class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 10)
        self.relu1 = nn.ReLU(inplace=True)

    def forward(self, x):
        # Flatten the input before the linear layer.
        return self.relu1(self.l1(x.view(x.size(0), -1)))

fq_weight = torch.quantization.FakeQuantize.with_args(
    observer=torch.quantization.MovingAverageMinMaxObserver.with_args(),
    quant_min=0, quant_max=255, dtype=torch.quint8)

fq_activation = torch.quantization.FakeQuantize.with_args(
    observer=torch.quantization.MovingAverageMinMaxObserver.with_args(),
    quant_min=0, quant_max=255, dtype=torch.quint8)


model = LeNet()
model.l1.qconfig = torch.quantization.QConfig(activation=fq_activation, weight=fq_weight)     
torch.quantization.prepare_qat(model, inplace=True)
model.l1.apply(torch.quantization.disable_fake_quant)


The above case works just fine. However, when I define my own FakeQuantize subclass as follows, it no longer works.

# A trivial subclass that adds a custom argument but otherwise
# behaves exactly like FakeQuantize.
class MyFakeQuantize(torch.quantization.FakeQuantize):
    def __init__(self, observer, quant_min, quant_max, n_cluster=0, **observer_kwargs):
        super().__init__(observer, quant_min, quant_max, **observer_kwargs)

fq_weight = MyFakeQuantize.with_args(
    observer=torch.quantization.MovingAverageMinMaxObserver.with_args(),
    quant_min=0, quant_max=255, dtype=torch.quint8)

fq_activation = MyFakeQuantize.with_args(
    observer=torch.quantization.MovingAverageMinMaxObserver.with_args(),
    quant_min=0, quant_max=255, dtype=torch.quint8)


model2 = LeNet()
model2.l1.qconfig = torch.quantization.QConfig(activation=fq_activation, weight=fq_weight)     
torch.quantization.prepare_qat(model2, inplace=True)
model2.l1.apply(torch.quantization.disable_fake_quant)
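
One way to see the difference is to inspect the fake_quant_enabled buffer on the prepared modules (a minimal check; it assumes the standard weight_fake_quant submodule that prepare_qat attaches, and the uint8 buffer layout of 1.7):

# For the first model the fake quant is switched off; for the second,
# the disable call had no effect:
print(model.l1.weight_fake_quant.fake_quant_enabled)   # tensor([0], dtype=torch.uint8)
print(model2.l1.weight_fake_quant.fake_quant_enabled)  # tensor([1], dtype=torch.uint8)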

Is this expected behavior?

hi @thyeros, I cannot reproduce this on master. Could you check whether you can reproduce it on v1.7 or master? Which version did you originally see the issue on?

This was with 1.7.0; please see the snapshot.

hi @thyeros, unfortunately we cannot repro this on master and this is not a known problem. Could you try the debug script in this gist (61ac9744858509e175d4ce50258782e4 on GitHub) and narrow down which exact part of disable_fake_quant is not working in your environment? In the debug script, I just copied the disable_fake_quant definition so it is easy to debug.
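
The core of the script is just the master-branch helper with a print added, something like this (a sketch of the idea, not the actual gist; it assumes FakeQuantizeBase, which exists on recent master):

from torch.quantization import FakeQuantizeBase

# Same logic as master's disable_fake_quant, plus a print so you can see
# which submodules match the isinstance check:
def disable_fake_quant_debug(mod):
    print(type(mod), isinstance(mod, FakeQuantizeBase))
    if isinstance(mod, FakeQuantizeBase):
        mod.disable_fake_quant()

model2.l1.apply(disable_fake_quant_debug)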

Hi Vasiliy,

First of all, thanks for the help with this; it is really appreciated. Your code doesn't work out of the box, as FakeQuantizeBase is not available.

Since I know that FakeQuantize is derived from FakeQuantizeBase, this looked strange to me as well, so I checked torch/quantization/fake_quantize.py and found that the PyTorch 1.7.0 in my image doesn't have FakeQuantizeBase at all, and has very different implementations of the disable/enable helpers for fake quant and observers.


[screenshot: torch/quantization/fake_quantize.py from the 1.7.0 install, showing the disable/enable helpers]

Can you tell if this is the right one for PyTorch 1.7.0?

Thanks for the additional info. It looks like what you are describing was not supported, and was added recently in https://github.com/pytorch/pytorch/pull/48072. Before that PR, the type(mod) == FakeQuantize check would not pass for custom fake quant classes; it should pass now. Could you try updating your PyTorch installation and see if it works on 1.7.1 or master?
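
In essence, the PR changed the helper from an exact type comparison to an isinstance check against a new FakeQuantizeBase class (roughly; see the PR for the actual diff):

# Before PR 48072 (e.g. in 1.7.0): exact type check, so subclasses are skipped.
def disable_fake_quant(mod):
    if type(mod) == FakeQuantize:
        mod.disable_fake_quant()

# After PR 48072: isinstance check, so custom subclasses match too.
def disable_fake_quant(mod):
    if isinstance(mod, FakeQuantizeBase):
        mod.disable_fake_quant()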

Looking at the releases page (https://github.com/pytorch/pytorch/releases), it doesn't look like this PR made it into 1.7.1. So, you could try master. Or, you could write your own function that disables fake quants (for example, by copying the code after the PR above) and call that, without having to update your PyTorch installation.
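
A minimal version of such a function (a sketch; on 1.7.0 there is no FakeQuantizeBase, but an isinstance check against FakeQuantize itself is enough to match your subclass):

from torch.quantization import FakeQuantize

# Drop-in replacement for torch.quantization.disable_fake_quant on 1.7.0:
# isinstance matches MyFakeQuantize as well, unlike the exact type check.
def my_disable_fake_quant(mod):
    if isinstance(mod, FakeQuantize):
        mod.disable_fake_quant()

model2.l1.apply(my_disable_fake_quant)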

Thanks Vasiliy for double-checking the latest PR. I will probably live with a custom implementation for now (like yours) and wait for the official release. Thanks again!
