Hi! I am using torch 2.0.0 and the same issue appears for me. I have a custom architecture based on a transformer model (Attention + FeedForward). I quantized the model and saved the state dict successfully, but when I try to load the weights back I get the same error. (I also reproduced it with torch==2.2.2 and torch==2.3.0.)
Preparing the model for weight loading
import torch
from torch.quantization import get_default_qat_qconfig, float_qparams_weight_only_qconfig

model.eval()
model.transformer.qconfig = get_default_qat_qconfig(args.quantize_engine)
# Embedding layers only support weight-only quantization
model.transformer.tok_embeddings.qconfig = float_qparams_weight_only_qconfig
torch.quantization.prepare(model.transformer, inplace=True)
torch.quantization.convert(model.transformer, inplace=True)
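For context, the checkpoint being loaded was produced by saving the converted model's state dict; a one-line sketch of that save step (the exact call may have differed slightly):

torch.save(model.state_dict(), "./models/test_model.pth")  # sketch of the save side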
Loading the weights
model_path = "./models/test_model.pth"
# Ensure the state_dict is an OrderedDict
state_dict = torch.load(model_path, map_location='cpu')
res = model.load_state_dict(state_dict, strict=False)
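For reference, the two key listings further below were printed with a simple loop along these lines (a minimal sketch; only the keys are shown):

# Keys in the checkpoint loaded from disk ("Model State Dict" below)
for key in state_dict.keys():
    print(key)
# Keys of the prepared/converted model in memory ("Keys in State Dict from Memory" below)
for key in model.state_dict().keys():
    print(key)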
Error:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[15], line 6
3 # Ensure the state_dict is an OrderedDict
4 state_dict = torch.load(model_path, map_location='cpu')
----> 6 res = model.load_state_dict(state_dict, strict=True)
File ~/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/torch/nn/modules/module.py:2139, in Module.load_state_dict(self, state_dict, strict, assign)
2132 out = hook(module, incompatible_keys)
2133 assert out is None, (
2134 "Hooks registered with ``register_load_state_dict_post_hook`` are not"
2135 "expected to return new values, if incompatible_keys need to be modified,"
2136 "it should be done inplace."
2137 )
-> 2139 load(self, state_dict)
2140 del load
2142 if strict:
File ~/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/torch/nn/modules/module.py:2127, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2125 child_prefix = prefix + name + '.'
2126 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2127 load(child, child_state_dict, child_prefix)
2129 # Note that the hook can modify missing_keys and unexpected_keys.
2130 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
File ~/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/torch/nn/modules/module.py:2127, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2125 child_prefix = prefix + name + '.'
2126 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2127 load(child, child_state_dict, child_prefix)
2129 # Note that the hook can modify missing_keys and unexpected_keys.
2130 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
[... skipping similar frames: Module.load_state_dict.<locals>.load at line 2127 (2 times)]
File ~/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/torch/nn/modules/module.py:2127, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2125 child_prefix = prefix + name + '.'
2126 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2127 load(child, child_state_dict, child_prefix)
2129 # Note that the hook can modify missing_keys and unexpected_keys.
2130 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
File ~/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/torch/nn/modules/module.py:2121, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2119 if assign:
2120 local_metadata['assign_to_params_buffers'] = assign
-> 2121 module._load_from_state_dict(
2122 local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
2123 for name, child in module._modules.items():
2124 if child is not None:
File ~/anaconda3/envs/tensorflow2_p310/lib/python3.10/site-packages/torch/ao/nn/quantized/modules/linear.py:220, in Linear._load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
216 version = local_metadata.get('version', None)
218 if version is None or version == 1:
219 # We moved the parameters into a LinearPackedParameters submodule
--> 220 weight = state_dict.pop(prefix + 'weight')
221 bias = state_dict.pop(prefix + 'bias')
222 state_dict.update({prefix + '_packed_params.weight': weight,
223 prefix + '_packed_params.bias': bias})
KeyError: 'transformer.layers.0.attention.wq.weight'
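Looking at the failing frame in torch/ao/nn/quantized/modules/linear.py, the pop of prefix + 'weight' only runs when version is None or 1, i.e. when the per-module version metadata of the state dict is missing, so the quantized Linear falls back to the legacy key layout even though the checkpoint already stores _packed_params. A workaround worth trying (a sketch under that assumption, not a confirmed fix) is to restore the _metadata attribute from the freshly converted in-memory model before calling load_state_dict:

import collections

state_dict = torch.load(model_path, map_location='cpu')
# If the loaded dict lost its _metadata attribute (e.g. it was rebuilt as a plain
# dict at some point), copy the metadata from the converted in-memory model so
# each quantized submodule sees the expected serialization version.
if not hasattr(state_dict, '_metadata'):
    reference = model.state_dict()
    state_dict = collections.OrderedDict(state_dict)
    state_dict._metadata = reference._metadata
res = model.load_state_dict(state_dict, strict=False)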
Model State Dict (keys in the saved checkpoint):
transformer.quant.scale
transformer.quant.zero_point
transformer.tok_embeddings._packed_params.dtype
transformer.tok_embeddings._packed_params._packed_weight
transformer.layers.0.quant.scale
transformer.layers.0.quant.zero_point
transformer.layers.0.attention.quant.scale
transformer.layers.0.attention.quant.zero_point
transformer.layers.0.attention.wq.scale
transformer.layers.0.attention.wq.zero_point
transformer.layers.0.attention.wq._packed_params.dtype
transformer.layers.0.attention.wq._packed_params._packed_params
transformer.layers.0.attention.wk.scale
transformer.layers.0.attention.wk.zero_point
transformer.layers.0.attention.wk._packed_params.dtype
transformer.layers.0.attention.wk._packed_params._packed_params
transformer.layers.0.attention.wv.scale
transformer.layers.0.attention.wv.zero_point
transformer.layers.0.attention.wv._packed_params.dtype
transformer.layers.0.attention.wv._packed_params._packed_params
transformer.layers.0.attention.wo.scale
transformer.layers.0.attention.wo.zero_point
transformer.layers.0.attention.wo._packed_params.dtype
transformer.layers.0.attention.wo._packed_params._packed_params
transformer.layers.0.feed_forward.quant.scale
transformer.layers.0.feed_forward.quant.zero_point
transformer.layers.0.feed_forward.ff.scale
transformer.layers.0.feed_forward.ff.zero_point
transformer.layers.0.feed_forward.w1.scale
transformer.layers.0.feed_forward.w1.zero_point
transformer.layers.0.feed_forward.w1._packed_params.dtype
transformer.layers.0.feed_forward.w1._packed_params._packed_params
transformer.layers.0.feed_forward.w2.scale
transformer.layers.0.feed_forward.w2.zero_point
transformer.layers.0.feed_forward.w2._packed_params.dtype
transformer.layers.0.feed_forward.w2._packed_params._packed_params
transformer.layers.0.feed_forward.w3.scale
transformer.layers.0.feed_forward.w3.zero_point
transformer.layers.0.feed_forward.w3._packed_params.dtype
transformer.layers.0.feed_forward.w3._packed_params._packed_params
transformer.layers.0.attention_norm.weight
transformer.layers.0.ffn_norm.weight
transformer.layers.0.ff.scale
transformer.layers.0.ff.zero_point
transformer.layers.0.adapter_after_attention.quant.scale
transformer.layers.0.adapter_after_attention.quant.zero_point
transformer.layers.0.adapter_after_attention.down1.scale
transformer.layers.0.adapter_after_attention.down1.zero_point
transformer.layers.0.adapter_after_attention.down1._packed_params.dtype
transformer.layers.0.adapter_after_attention.down1._packed_params._packed_params
transformer.layers.0.adapter_after_attention.up1.scale
transformer.layers.0.adapter_after_attention.up1.zero_point
transformer.layers.0.adapter_after_attention.up1._packed_params.dtype
transformer.layers.0.adapter_after_attention.up1._packed_params._packed_params
transformer.norm.weight
classifier.fc_layers.0.weight
classifier.fc_layers.0.bias
classifier.output_layer.weight
classifier.output_layer.bias
Keys in State Dict from Memory:
The keys of model.state_dict() after prepare/convert are identical to the saved checkpoint keys listed above.
Versions
PyTorch version: 2.0.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Amazon Linux 2 (x86_64)
GCC version: (GCC) 7.3.1 20180712 (Red Hat 7.3.1-17)
Clang version: Could not collect
CMake version: version 3.29.2
Libc version: glibc-2.26
Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.10.214-202.855.amzn2.x86_64-x86_64-with-glibc2.26
Is CUDA available: N/A
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration:
GPU 0: NVIDIA A10G
GPU 1: NVIDIA A10G
GPU 2: NVIDIA A10G
GPU 3: NVIDIA A10G
Nvidia driver version: 535.161.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: N/A
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 48
On-line CPU(s) list: 0-47
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 1
NUMA node(s): 1
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7R32
Stepping: 0
CPU MHz: 2428.708
BogoMIPS: 5599.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32K
L1i cache: 32K
L2 cache: 512K
L3 cache: 16384K
NUMA node0 CPU(s): 0-47
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save rdpid
Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] numpydoc==1.6.0
[conda] blas 2.120 mkl conda-forge
[conda] blas-devel 3.9.0 20_linux64_mkl conda-forge
[conda] libblas 3.9.0 20_linux64_mkl conda-forge
[conda] libcblas 3.9.0 20_linux64_mkl conda-forge
[conda] liblapack 3.9.0 20_linux64_mkl conda-forge
[conda] liblapacke 3.9.0 20_linux64_mkl conda-forge
[conda] mkl 2023.2.0 h84fe81f_50496 conda-forge
[conda] mkl-devel 2023.2.0 ha770c72_50496 conda-forge
[conda] mkl-include 2023.2.0 h84fe81f_50496 conda-forge
[conda] mkl-service 2.4.1 py310hc72dfd8_0 conda-forge
[conda] mkl_fft 1.3.8 py310ha3dbc2a_1 conda-forge
[conda] numexpr 2.9.0 mkl_py310hc8c826e_0 conda-forge
[conda] numpy 1.26.4 pypi_0 pypi
[conda] numpydoc 1.6.0 pyhd8ed1ab_0 conda-forge