Hi PyTorch community,
TL;DR: DistilBert's nn.quantized.Linear raises a KeyError when loading from a state_dict. Saving writes the state_dict in the version 3 format, but on loading local_metadata.get('version', None) evaluates to None, so the loader falls back to the version 1 format.
I have a problem with loading a DistilBert classifier. I load it from a pre-trained model, fine-tune it, quantize it, and then save its state_dict. The issue appears when this quantized version is saved and reloaded. When DynamicQuantizedLinear generates keys, it uses this format:
_distilbert.transformer.layer.0.attention.q_lin._packed_params.weight
# Version 3
#   self
#   |--- _packed_params : (Tensor, Tensor) representing (weight, bias)
#                         of LinearPackedParams
#   |--- dtype : torch.dtype
Printing the value stored in the state_dict at that key:
# print(state_dict['_distilbert.transformer.layer.0.attention.q_lin._packed_params.weight'])
tensor([[ 0.0357,  0.0365,  0.0119,  ..., -0.0230,  0.0199,  0.0397],
        [ 0.0119, -0.0349,  0.0048,  ...,  0.0294, -0.0127, -0.0119],
        [ 0.0540,  0.0159, -0.0032,  ...,  0.0008, -0.0183, -0.0016],
        ...,
        [ 0.0064, -0.0079,  0.0302,  ..., -0.0199,  0.0008, -0.0095],
        [ 0.0024, -0.0056,  0.0183,  ...,  0.0008,  0.0175,  0.0270],
        [-0.0024, -0.0119, -0.0238,  ...,  0.0294,  0.0199,  0.0175]],
       size=(768, 768), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.0007942558731883764,
       zero_point=0)
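For context, here is roughly how the quantized model is produced and saved. This is a minimal sketch with placeholder model and path names rather than my exact code; my real model wraps DistilBert in a custom classifier, which is where the _distilbert. prefix in the keys comes from:

import torch
from transformers import DistilBertForSequenceClassification

# Placeholder: stand-in for my actual classifier that wraps DistilBert.
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
# ... fine-tuning happens here ...

# Dynamic quantization replaces nn.Linear layers with DynamicQuantizedLinear.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# Save only the state_dict.
torch.save(quantized_model.state_dict(), 'quantized_distilbert.pt')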
However, when loading the model using the same Python environment and on the same machine, the de-serialization fails with the following error:
File "/home/.venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 827, in load
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
File "/home/.venv/lib/python3.6/site-packages/torch/nn/quantized/modules/linear.py", line 207, in _load_from_state_dict
weight = state_dict.pop(prefix + 'weight')
KeyError: '_distilbert.transformer.layer.0.attention.q_lin.weight'
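The reload step that triggers the error looks roughly like this (again a sketch with the same placeholder names). I rebuild and re-quantize the same architecture so the module tree contains DynamicQuantizedLinear layers, then load the saved state_dict into it:

import torch
from transformers import DistilBertForSequenceClassification

# Rebuild the same architecture and quantize it before loading, so that
# each q_lin module expects the `_packed_params.*` keys.
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

state_dict = torch.load('quantized_distilbert.pt')
quantized_model.load_state_dict(state_dict)  # KeyError raised here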
Here’s what the de-serialization method that fails looks like:
# file: torch/nn/quantized/modules/linear.py
# ===== Deserialization methods =====
# Counterpart to the serialization methods, we must pack the serialized QTensor
# weight into its packed format for use by the FBGEMM ops.
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    self.scale = float(state_dict[prefix + 'scale'])
    state_dict.pop(prefix + 'scale')

    self.zero_point = int(state_dict[prefix + 'zero_point'])
    state_dict.pop(prefix + 'zero_point')

    version = local_metadata.get('version', None)

    if version is None or version == 1:
        # We moved the parameters into a LinearPackedParameters submodule
        weight = state_dict.pop(prefix + 'weight')
        bias = state_dict.pop(prefix + 'bias')
        state_dict.update({prefix + '_packed_params.weight': weight,
                           prefix + '_packed_params.bias': bias})

    super(Linear, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                              missing_keys, unexpected_keys, error_msgs)
The issue seems to be that the backward-compatibility if-statement defaults to assuming the model was serialized with an earlier format version whenever version is None, which is the wrong assumption in my case. Changing if version is None or version == 1: to if version == 1: fixes the issue for me, but I'd like a more sustainable solution.
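From reading module.py, local_metadata appears to come from the _metadata attribute that nn.Module.state_dict() attaches to the returned OrderedDict, so my guess is that this attribute gets lost somewhere between saving and loading. A quick check I ran to see what my saved state_dict actually carries (a sketch; the checkpoint path is a placeholder):

import torch

state_dict = torch.load('quantized_distilbert.pt')

# state_dict() normally records per-module versions in a `_metadata`
# attribute on the OrderedDict; copying it into a plain dict drops it.
metadata = getattr(state_dict, '_metadata', None)
print(metadata is None)  # True would explain why version is None on load
if metadata is not None:
    # Expected to contain something like {'version': 3} for this module.
    print(metadata.get('_distilbert.transformer.layer.0.attention.q_lin', {}))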
How do I make sure my model’s version evaluates to the correct value?
Thanks in advance for any help!