import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub
class QuantizableModel(nn.Module):
    """Minimal conv model wrapped for eager-mode static quantization.

    A single 3x3 ``Conv2d`` (3 -> 2 channels, stride 1, padding 1, with bias)
    sits between a ``QuantStub`` and a ``DeQuantStub`` so that
    ``torch.quantization.prepare``/``convert`` can quantize the input and the
    conv and dequantize the output.

    Any positional or keyword arguments passed to the constructor are
    accepted and ignored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Registration order (conv, quant, dequant) determines the order of
        # keys in state_dict(), so keep it stable.
        self.conv = nn.Conv2d(3, 2, 3, 1, 1, groups=1, bias=True)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        # He-normal weights scaled by fan_out; bias starts at zero.
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out')
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        """Run ``x`` through quant stub -> conv -> dequant stub and return it."""
        x = self.quant(x)
        x = self.conv(x)
        return self.dequant(x)
def main():
    """Quantize QuantizableModel, trace it, and print state_dict keys.

    Prints the state_dict keys of the converted eager model and of its
    ``torch.jit.trace``-d counterpart, so the two can be compared.
    """
    model = QuantizableModel().eval()
    inp = torch.randn(1, 3, 224, 224)

    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    # Attach observers in place so the calibration pass can record ranges.
    torch.quantization.prepare(model, inplace=True)

    # Calibration, conversion and tracing are forward-only; no_grad avoids
    # building an autograd graph through the float conv weights.
    with torch.no_grad():
        # Dummy calibration: a single forward pass through the observers.
        model(inp)
        # Swap observed float modules for their quantized counterparts.
        torch.quantization.convert(model, inplace=True)
        print("before traced: ", model.state_dict().keys())

        traced_model = torch.jit.trace(model, inp).eval()

    # NOTE(review): the quantized conv appears to keep its weight/bias in
    # packed params; if those are not exposed as tensors on the traced
    # ScriptModule, they will not appear in its state_dict(). Confirm that
    # they still persist via torch.jit.save / torch.jit.load before treating
    # the missing keys as data loss.
    print("after traced: ", traced_model.state_dict().keys())


if __name__ == "__main__":
    main()
Output of the code above under PyTorch 1.5:
before traced: odict_keys(['conv.weight', 'conv.scale', 'conv.zero_point', 'conv.bias', 'quant.scale', 'quant.zero_point'])
after traced: odict_keys(['conv._packed_params', 'quant.scale', 'quant.zero_point'])
Output of the same code under PyTorch 1.6:
before traced: odict_keys(['conv.weight', 'conv.bias', 'conv.scale', 'conv.zero_point', 'quant.scale', 'quant.zero_point'])
after traced: odict_keys(['quant.scale', 'quant.zero_point'])
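In PyTorch 1.6, the quantized conv's parameters are missing from `state_dict()` after tracing. Before tracing they appear as `conv.weight`, `conv.bias`, `conv.scale` and `conv.zero_point`. In PyTorch 1.5 the traced model still exposed them as `conv._packed_params`. Is this a bug or an intended change?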