I'm experiencing the same behaviour whether I run QAT from scratch or load a pretrained model first.
qat_model = MLP(Ni=1, Nh=256, No=1,)
The MLP consists of:
QuantStub()
Linear()
DeQuantStub()
an activation function I'm trying that has no quantized implementation
– repeated 5 times (rough sketch below).
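For reference, a simplified sketch of the model; MyActivation is just a stand-in for the real activation (which has no quantized kernel), and the layer sizes are illustrative:

import torch
import torch.nn as nn

class MyActivation(nn.Module):
    # stand-in for the custom activation that has no quantized implementation
    def forward(self, x):
        return torch.sin(x)  # placeholder behaviour

class MLP(nn.Module):
    def __init__(self, Ni=1, Nh=256, No=1):
        super().__init__()
        sizes = [Ni] + [Nh] * 4 + [No]
        layers = []
        for i in range(len(sizes) - 1):
            layers += [
                torch.quantization.QuantStub(),
                nn.Linear(sizes[i], sizes[i + 1]),
                torch.quantization.DeQuantStub(),
                MyActivation(),
            ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)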
Load Model
qat_model.load_state_dict(torch.load('dict/model_state_dict.pt', map_location='cpu'))
qat_model.train()
activation_bitwidth = 8  # whatever bit width you want
bitwidth = 8  # whatever bit width you want
fq_activation = torch.quantization.FakeQuantize.with_args(
    observer=torch.quantization.MinMaxObserver.with_args(
        quant_min=0,
        quant_max=2 ** activation_bitwidth - 1,
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,
        reduce_range=False,
    )
)
fq_weights = torch.quantization.FakeQuantize.with_args(
    observer=torch.quantization.MinMaxObserver.with_args(
        quant_min=-(2 ** bitwidth) // 2,
        quant_max=(2 ** bitwidth) // 2 - 1,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=False,
    )
)
intB_qat_qconfig = torch.quantization.QConfig(activation=fq_activation, weight=fq_weights)
qat_model.qconfig = intB_qat_qconfig
torch.ao.quantization.prepare_qat(qat_model, inplace=True)
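As a sanity check (assuming the setup above), I can list the fake-quant modules that prepare_qat attached and confirm they carry the intended ranges:

# confirm prepare_qat attached FakeQuantize modules with the intended ranges
for name, module in qat_model.named_modules():
    if isinstance(module, torch.ao.quantization.FakeQuantize):
        print(name, module.quant_min, module.quant_max, module.dtype)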
The loss curve is similar with and without the qat_model.load_state_dict() call; it feels like the pretrained weights are being overwritten.
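To check whether the pretrained weights really survive prepare_qat, I can re-run the setup with a snapshot taken right after load_state_dict (a sketch; it assumes the parameter names are unchanged by the module swap):

# snapshot the float weights right after load_state_dict ...
ref = {n: p.detach().clone()
       for n, p in qat_model.named_parameters() if n.endswith('weight')}

torch.ao.quantization.prepare_qat(qat_model, inplace=True)

# ... and compare after prepare_qat; the swap to qat modules should not
# change the underlying float weights
for n, p in qat_model.named_parameters():
    if n in ref:
        print(n, torch.allclose(p, ref[n]))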