I used the quantize_fx module to quantize my own model. The code is:
    import copy

    import torch
    import yaml

    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

    from modules.keypoint_detector import KPDetector
    from modules.generator import OcclusionAwareGenerator
    from modules.model import GeneratorFullModel

    if __name__ == "__main__":
        checkpoint1 = torch.load('/mnt/vox-cpk.pth.tar')
        with open('config/vox-256.yaml') as f:
            config = yaml.safe_load(f)

        cuda_device = torch.device('cuda')

        # build the keypoint detector and load its weights
        kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                                 **config['model_params']['common_params'])
        if torch.cuda.is_available():
            kp_detector.to(cuda_device)
        kp_detector.load_state_dict(checkpoint1['kp_detector'])

        # build the generator and load its weights
        generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                            **config['model_params']['common_params'])
        if torch.cuda.is_available():
            generator.to(cuda_device)
        generator.load_state_dict(checkpoint1['generator'])

        train_params = config['train_params']
        model_fp32 = GeneratorFullModel(kp_detector, generator, None, train_params)

        model_to_quantize = copy.deepcopy(model_fp32)
        model_to_quantize.eval()

        # FX graph mode post-training static quantization
        qconfig = get_default_qconfig("fbgemm")
        qconfig_dict = {"": qconfig}
        prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
        # calibrate(prepared_model, data_loader)
        quantized_model = convert_fx(prepared_model)
        print(quantized_model)
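For reference, the calibrate call I commented out is just a plain post-training-quantization calibration loop; a minimal sketch, assuming data_loader yields batches in whatever format GeneratorFullModel's forward expects:

    def calibrate(model, data_loader):
        # feed representative batches through the prepared model so the
        # inserted observers can record activation ranges
        model.eval()
        with torch.no_grad():
            for batch in data_loader:
                model(batch)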
GeneratorFullModel consists of two separate models, and the error is:
    Traceback (most recent call last):
      File "/home/lab239-5/users/wangxin/first-order-model-master/ccc.py", line 115, in <module>
        quantized_model = convert_fx(prepared_model)
      File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 658, in convert_fx
        return _convert_fx(
      File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 563, in _convert_fx
        quantized = convert(
      File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/ao/quantization/fx/convert.py", line 754, in convert
        model = lower_to_fbgemm(model, qconfig_map, node_name_to_scope)
      File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py", line 14, in lower_to_fbgemm
        return _lower_to_native_backend(model, qconfig_map, node_name_to_scope)
      File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py", line 958, in _lower_to_native_backend
        _lower_static_weighted_ref_functional(model, qconfig_map)
      File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py", line 606, in _lower_static_weighted_ref_functional
        (q_node, relu_node, func_node) = _match_static_pattern(
      File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py", line 450, in _match_static_pattern
        assert i < len(ref_node.args),
    AssertionError: Dequantize index 1 exceeded reference node's arg length 1
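For what it's worth, quantizing the two wrapped networks separately instead of the full GeneratorFullModel wrapper would look roughly like this (an untested sketch using the same qconfig as above; the calibrate calls are placeholders for my own data):

    # untested sketch: prepare/convert each submodule on its own, in case
    # the failing lowering pattern only appears in the combined graph
    kp_to_quantize = copy.deepcopy(kp_detector).eval()
    gen_to_quantize = copy.deepcopy(generator).eval()

    prepared_kp = prepare_fx(kp_to_quantize, {"": qconfig})
    prepared_gen = prepare_fx(gen_to_quantize, {"": qconfig})

    # calibrate(prepared_kp, data_loader)
    # calibrate(prepared_gen, data_loader)

    quantized_kp = convert_fx(prepared_kp)
    quantized_gen = convert_fx(prepared_gen)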