Loading Quantized Model from State_Dict with Version==None

Hi PyTorch community,

TL;DR: DistilBert’s nn.quantized.Linear raises a KeyError when loading from a state_dict. Saving the state_dict uses the version 3 format, but when loading, local_metadata.get('version', None) returns None, so the loader falls back to the version 1 format.

I have a problem loading a DistilBert classifier. I load it from a pre-trained model, fine-tune it, quantize it, and then save its state_dict. The issue happens when saving and reloading this quantized version. When DynamicQuantizedLinear generates keys, it uses this format:

# Version 3
#   self
#   |--- _packed_params : (Tensor, Tensor) representing (weight, bias)
#                         of LinearPackedParams
#   |--- dtype : torch.dtype
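To see which format each module actually records, it can help to print both the keys and the _metadata attribute that state_dict() attaches. The sketch below assumes a small hypothetical toy model in place of the real DistilBert classifier and uses torch.quantization.quantize_dynamic; exact key names differ between PyTorch releases, so treat the printed names as illustrative:

```python
import torch
import torch.nn as nn

# Hypothetical toy stand-in for the DistilBert classifier; any model with nn.Linear layers works.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
qmodel = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

sd = qmodel.state_dict()

# The serialized entries; their names depend on the PyTorch release.
for key in sd:
    print("key:", key)

# state_dict() also attaches an OrderedDict mapping each module prefix
# (without the trailing dot) to {'version': <that module's _version>}.
for prefix, meta in sd._metadata.items():
    print("metadata:", repr(prefix), meta)
```

If _metadata is missing from the state_dict you load, every module sees version None during loading.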

Printing the state_dict entry for that key:

# print(state_dict['_distilbert.transformer.layer.0.attention.q_lin._packed_params.weight'])
tensor([[ 0.0357,  0.0365,  0.0119,  ..., -0.0230,  0.0199,  0.0397],
        [ 0.0119, -0.0349,  0.0048,  ...,  0.0294, -0.0127, -0.0119],
        [ 0.0540,  0.0159, -0.0032,  ...,  0.0008, -0.0183, -0.0016],
        [ 0.0064, -0.0079,  0.0302,  ..., -0.0199,  0.0008, -0.0095],
        [ 0.0024, -0.0056,  0.0183,  ...,  0.0008,  0.0175,  0.0270],
        [-0.0024, -0.0119, -0.0238,  ...,  0.0294,  0.0199,  0.0175]],
       size=(768, 768), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.0007942558731883764,

However, when I load the model in the same Python environment on the same machine, deserialization fails with the following error:

  File "/home/.venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 827, in load
    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
  File "/home/.venv/lib/python3.6/site-packages/torch/nn/quantized/modules/linear.py", line 207, in _load_from_state_dict
    weight = state_dict.pop(prefix + 'weight')
KeyError: '_distilbert.transformer.layer.0.attention.q_lin.weight'
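One way to reproduce this exact KeyError, assuming the checkpoint lost its _metadata somewhere between saving and loading, is sketched below. The toy model and the make_quantized helper are hypothetical, and copying into a fresh OrderedDict is just one way the attribute can get dropped:

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def make_quantized():
    # Hypothetical toy model standing in for the fine-tuned DistilBert classifier.
    model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
    return torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)


saved = make_quantized().state_dict()

# A new OrderedDict does not carry the _metadata attribute over, so every
# module gets local_metadata == {} and therefore version None on load.
stripped = OrderedDict(saved)

error = None
try:
    make_quantized().load_state_dict(stripped)
except KeyError as exc:
    # The loader looks for the version-1 key name, e.g. '0.weight'.
    error = exc
print("KeyError:", error)
```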

Here is the deserialization method that fails:

# file: torch/nn/quantized/modules/linear.py 
    # ===== Deserialization methods =====
    # Counterpart to the serialization methods, we must pack the serialized QTensor
    # weight into its packed format for use by the FBGEMM ops.
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self.scale = float(state_dict[prefix + 'scale'])
        state_dict.pop(prefix + 'scale')

        self.zero_point = int(state_dict[prefix + 'zero_point'])
        state_dict.pop(prefix + 'zero_point')

        version = local_metadata.get('version', None)
        if version is None or version == 1:
            # We moved the parameters into a LinearPackedParameters submodule
            weight = state_dict.pop(prefix + 'weight')
            bias = state_dict.pop(prefix + 'bias')
            state_dict.update({prefix + '_packed_params.weight': weight,
                               prefix + '_packed_params.bias': bias})

        super(Linear, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                                  missing_keys, unexpected_keys, error_msgs)

The issue seems to be that the backward-compatibility if-statement assumes the model was serialized with an earlier format whenever version is None, and that assumption is wrong in my case. Changing "if version is None or version == 1:" to "if version == 1:" fixes the issue for me, but I’d like a more sustainable solution.
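A quick way to check what the loader will actually see for the failing module is a small diagnostic that mirrors the lookup Module.load_state_dict performs (metadata.get(prefix[:-1], {}), where prefix ends with a dot). The helper name loader_version and the example entries are hypothetical; pass it the loaded state_dict and the module name from the traceback, without the trailing parameter name:

```python
from collections import OrderedDict


def loader_version(state_dict, module_name):
    """Return the 'version' that load_state_dict would hand to module_name.

    Hypothetical helper. None means quantized Linear will take its
    version-1 branch and look for prefix + 'weight'.
    """
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is None:
        return None  # no metadata at all: every module sees version None
    return metadata.get(module_name, {}).get("version")


# Example: one dict with metadata, one without.
ok = OrderedDict()
ok._metadata = OrderedDict([("", {"version": 1}), ("q_lin", {"version": 3})])
print(loader_version(ok, "q_lin"))             # 3
print(loader_version(OrderedDict(), "q_lin"))  # None
```

For this thread the module name would be '_distilbert.transformer.layer.0.attention.q_lin'.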

How do I make sure my model’s version evaluates to the correct value?

Thanks in advance for any help!

Hi @salimmj,

If you are seeing this on a recent version of PyTorch (v1.5 or the nightlies), would you mind filing a GitHub issue?

For a quick local fix, you can also modify the version metadata stored in the checkpoint. Here is a code snippet (for a different case) that does something similar:

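    import torch

    def adjust_convbn_metadata(mod, prefix, old_state_dict):
        # Recurse through submodules, building the same dotted prefixes
        # that state_dict() uses as keys in old_state_dict._metadata.
        for name, child in mod._modules.items():
            new_prefix = prefix + '.' + name if prefix != '' else name
            if isinstance(child, torch.nn.intrinsic.qat.ConvBn2d):
                # Override the recorded version so ConvBn2d loads this entry as version 2
                old_state_dict._metadata[new_prefix]['version'] = 2
            adjust_convbn_metadata(child, new_prefix, old_state_dict)

    # model and checkpoint come from your own loading code
    adjust_convbn_metadata(model, '', checkpoint['model'])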
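The same idea can be generalized for this thread's case: rather than hard-coding one module type and one version, rebuild _metadata from a freshly constructed model so that every prefix, including the nested _packed_params submodules, gets the _version of the class that will load it. This is a sketch under one important assumption, namely that the checkpoint was written by the same PyTorch version that loads it (which holds here, since saving and loading happen in the same environment). The restore_metadata and make_quantized names and the toy model are hypothetical:

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def restore_metadata(model, state_dict):
    """Return a copy of state_dict with _metadata mirroring model.state_dict().

    Hypothetical helper. Only valid when the checkpoint was written by the
    same PyTorch version, because it trusts each class's current _version.
    """
    state_dict = OrderedDict(state_dict)
    # named_modules() yields ('', model) first, then dotted names such as
    # '0' or '0._packed_params', matching state_dict()'s metadata keys.
    state_dict._metadata = OrderedDict(
        (name, {"version": module._version}) for name, module in model.named_modules()
    )
    return state_dict


def make_quantized():
    # Hypothetical toy model standing in for the quantized DistilBert classifier.
    model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
    return torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)


source = make_quantized()
stripped = OrderedDict(source.state_dict())  # simulates a checkpoint without _metadata

target = make_quantized()
target.load_state_dict(restore_metadata(target, stripped))

x = torch.randn(3, 8)
print(torch.allclose(source(x), target(x)))
```

With the real checkpoint this would look like model.load_state_dict(restore_metadata(model, model_checkpoint['state_dict'])), where model is the quantized DistilBert instance. If the checkpoint came from an older PyTorch release, the versions would have to match that release instead.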

Hi @Vasiliy_Kuznetsov,

Thanks for your help; I was actually just working on this. I ended up fixing it like this:

from collections import OrderedDict

# This is a temporary fix for https://discuss.pytorch.org/t/loading-quantized-model-from-state-dict-with-version-none/89042
model_checkpoint['state_dict'] = OrderedDict(model_checkpoint['state_dict'])
# Attach metadata only if the state_dict has none
if not hasattr(model_checkpoint['state_dict'], '_metadata'):
    model_checkpoint['state_dict']._metadata = OrderedDict({'version': 2})
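Whether the hasattr check in this fix can ever find existing metadata depends on how the dict is built. A small pure-Python check (no torch needed; the key and versions are made up for illustration) shows that OrderedDict(...) creates a new object without instance attributes such as _metadata, while a pickle round trip, which torch.save and torch.load are built on, keeps them. In a real checkpoint, _metadata maps each module prefix to its own {'version': N} dict:

```python
import pickle
from collections import OrderedDict

# Shaped like a real checkpoint: _metadata maps module prefixes to {'version': N}.
sd = OrderedDict([("q_lin.scale", 1.0)])
sd._metadata = OrderedDict([("", {"version": 1}), ("q_lin", {"version": 3})])

copied = OrderedDict(sd)
print(hasattr(copied, "_metadata"))  # False: a new OrderedDict, attributes are not copied

restored = pickle.loads(pickle.dumps(sd))
print(restored._metadata["q_lin"])  # {'version': 3}
```

Applied to the fix above, this means the OrderedDict(...) line always produces a dict without _metadata, so the hasattr branch always runs and replaces whatever metadata the checkpoint carried.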

Your code seems more specific, so I wonder whether mine could break. For now I only need to change the version for the Linear layers, so I don’t know if doing it this way will break something else.

I will file a GitHub issue!
