Error in QAT evaluation

When I run QAT, training is normal, but when I try to evaluate the QAT model, the error "length of scales must equal to channel" confuses me.
I use PyTorch 1.4.0, and my code is:

# Training is normal
net = MyQATNet()
net.fuse_model()
net.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(net, inplace=True)
net_eval = net
net.apply(torch.quantization.enable_observer)
net.apply(torch.quantization.enable_fake_quant)

# Evaluate
qat_model = copy.deepcopy(net_eval)
qat_model.to(torch.device('cpu'))
qat_model.eval()
torch.quantization.convert(qat_model, inplace=True)  # Error is raised here

@ptrblck Can you have a look?

Can anybody help me? Thanks a lot.

Can you paste your network definition and the input you use to run the model? It might be a problem with your input.

I think I found where the error comes from: when I use DataParallel I get this error, but with a single GPU there is no error.
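In case it helps, a minimal sketch of unwrapping `nn.DataParallel` before calling `convert`, so the quantized modules are converted on the underlying model rather than the wrapper (the helper name `unwrap_dataparallel` is my own, not a torch API):

```python
import torch.nn as nn


def unwrap_dataparallel(model: nn.Module) -> nn.Module:
    """Return the underlying module if the model is wrapped in nn.DataParallel."""
    if isinstance(model, nn.DataParallel):
        return model.module
    return model


# Usage sketch: unwrap, move to CPU, set eval mode, then convert
# qat_model = unwrap_dataparallel(qat_model).cpu().eval()
# torch.quantization.convert(qat_model, inplace=True)
```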

How can I save the QAT-trained model? Whether I save the state dict with torch.save(qat_model.state_dict(), 'qat_model.pth') or directly save the training model with torch.save(net, 'net.pth'), loading the pretrained QAT model fails. For qat_model, the saved key looks like conv1.0.activation_post_process.scale; for net, there is no conv1.0.activation_post_process.scale at all, but the expected key is conv1.0.0.activation_post_process.scale, so a KeyError happens. When I check the model definition, the expected key is the correct one.

I think somebody has the same question as me: https://github.com/pytorch/pytorch/issues/32691

We recently fixed a bug with QAT and DataParallel; please try the PyTorch nightly to see if the issue still persists. cc @Vasiliy_Kuznetsov

You mean the load_state_dict KeyError is also solved in the newest version of PyTorch? Thanks.

I think somebody has the same error: Issue with Quantization

I think I know the answer:

# PyTorch 1.4
# save
torch.save(net.state_dict(), 'xx')  # state_dict of the QAT-prepared fp32 model
# load: prepare the model the same way as before training, then load the weights
net.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
net.fuse_model()
torch.quantization.prepare_qat(net, inplace=True)
state_dict = torch.load('xx', map_location=torch.device('cpu'))
# if trained with DataParallel, remove the 'module.' prefix from the keys first
net.load_state_dict(state_dict)
# torch.quantization.convert(net, inplace=True)  # convert it to an int8 model
x = torch.randn((2, 3, 300, 300))
y = net(x)
print(y)
print('Success')
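For the "remove module." step mentioned in the code above, here is a hedged sketch of stripping the module. prefix that nn.DataParallel prepends to state-dict keys (the helper strip_module_prefix is hypothetical, not a torch API):

```python
from collections import OrderedDict


def strip_module_prefix(state_dict):
    """Remove the 'module.' prefix that nn.DataParallel prepends to keys."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith('module.'):
            key = key[len('module.'):]
        cleaned[key] = value
    return cleaned


# Usage sketch:
# state_dict = strip_module_prefix(torch.load('xx', map_location='cpu'))
# net.load_state_dict(state_dict)
```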

Yes, https://github.com/pytorch/pytorch/pull/37032 fixes an error for DataParallel. @xieydd you can try the nightly to verify if this fixes your problem.

The nightly fixes my problem, thanks all.
