Dynamic-range quantization for HF models seems to produce spurious outputs

I tried RoBERTa-Base and BERT-Base with random inputs. In both cases, the outputs of the dynamic-range quantized models are not close to those of the original models.
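
Roughly, this is what the notebook does (a minimal sketch; the exact checkpoints, input shapes, and tolerances are in the Colab):

import torch
from transformers import AutoModel

model_name = "roberta-base"  # also tried a BERT-Base checkpoint
model = AutoModel.from_pretrained(model_name).eval()

# Post-training dynamic-range quantization of the linear layers.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# Random inputs, fed to both models.
input_ids = torch.randint(0, model.config.vocab_size, (1, 128))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    original = model(input_ids, attention_mask=attention_mask).last_hidden_state
    quantized = quantized_model(input_ids, attention_mask=attention_mask).last_hidden_state

print(torch.allclose(original, quantized, atol=1e-2))  # False for me
print((original - quantized).abs().max())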

Here’s the Colab Notebook to reproduce the problem: Google Colab

Any solutions?

Hi Sayak,

I see you’re using torch.quantization.quantize_dynamic. If you use static quantization (torch.quantization.quantize), do you see the same result? One potential issue is that PyTorch quantization currently does not have great support for these models, since by default we only quantize the linear layers, not the attention layers. Another thing to try is quantization-aware training (QAT), which often improves the accuracy of quantized models. You can find more information about this here.
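
Just to illustrate the first point: only the module types you pass to quantize_dynamic get swapped, and everything else (embeddings, LayerNorm, the attention matmuls and softmax) stays in fp32. A quick way to check which modules were actually replaced (rough sketch, the model name is just a placeholder):

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("roberta-base").eval()

# Only nn.Linear weights are dynamically quantized here.
quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# List the modules that were swapped to quantized versions.
for name, module in quantized.named_modules():
    if "quantized" in type(module).__module__:
        print(name, "->", type(module).__name__)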

Best,
-Andrew

Thanks, Andrew.

When attempting static quantization, I get:

AssertionError: Embedding quantization is only supported with float_qparams_weight_only_qconfig.

So, I tried:

roberta_model.qconfig = (
    torch.quantization.qconfig.float_qparams_weight_only_qconfig
)
roberta_quant_model = torch.quantization.prepare(roberta_model)
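
The calibration pass itself was along these lines (a rough sketch; test_loader is the DataLoader from the notebook, with batches that are dicts of input_ids / attention_mask):

# Rough sketch of the calibration pass.
roberta_quant_model.eval()
with torch.no_grad():
    for batch in test_loader:
        roberta_quant_model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )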

After calibration, when I finally called:

roberta_quant_model = torch.quantization.convert(roberta_quant_model)

It ran into:

---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
Input In [19], in <cell line: 3>()
      1 from tqdm.auto import tqdm
----> 3 roberta_quant_model = torch.quantization.convert(roberta_quant_model)
      5 for batch in tqdm(test_loader):
      6     with torch.no_grad():

File ~/.local/bin/.virtualenvs/plm/lib/python3.8/site-packages/torch/ao/quantization/quantize.py:505, in convert(module, mapping, inplace, remove_qconfig, convert_custom_config_dict)
    503 if not inplace:
    504     module = copy.deepcopy(module)
--> 505 _convert(
    506     module, mapping, inplace=True,
    507     convert_custom_config_dict=convert_custom_config_dict)
    508 if remove_qconfig:
    509     _remove_qconfig(module)

File ~/.local/bin/.virtualenvs/plm/lib/python3.8/site-packages/torch/ao/quantization/quantize.py:541, in _convert(module, mapping, inplace, convert_custom_config_dict)
    536 for name, mod in module.named_children():
    537     # both fused modules and observed custom modules are
    538     # swapped as one unit
    539     if not isinstance(mod, _FusedModule) and \
    540        type(mod) not in custom_module_class_mapping:
--> 541         _convert(mod, mapping, True,  # inplace
    542                  convert_custom_config_dict)
    543     reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
    545 for key, value in reassign.items():

File ~/.local/bin/.virtualenvs/plm/lib/python3.8/site-packages/torch/ao/quantization/quantize.py:541, in _convert(module, mapping, inplace, convert_custom_config_dict)
    536 for name, mod in module.named_children():
    537     # both fused modules and observed custom modules are
    538     # swapped as one unit
    539     if not isinstance(mod, _FusedModule) and \
    540        type(mod) not in custom_module_class_mapping:
--> 541         _convert(mod, mapping, True,  # inplace
    542                  convert_custom_config_dict)
    543     reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
    545 for key, value in reassign.items():

File ~/.local/bin/.virtualenvs/plm/lib/python3.8/site-packages/torch/ao/quantization/quantize.py:543, in _convert(module, mapping, inplace, convert_custom_config_dict)
    539     if not isinstance(mod, _FusedModule) and \
    540        type(mod) not in custom_module_class_mapping:
    541         _convert(mod, mapping, True,  # inplace
    542                  convert_custom_config_dict)
--> 543     reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
    545 for key, value in reassign.items():
    546     module._modules[key] = value

File ~/.local/bin/.virtualenvs/plm/lib/python3.8/site-packages/torch/ao/quantization/quantize.py:568, in swap_module(mod, mapping, custom_module_class_mapping)
    566     swapped = True
    567 elif type(mod) in mapping:
--> 568     new_mod = mapping[type(mod)].from_float(mod)
    569     swapped = True
    571 if swapped:
    572     # Preserve module's pre forward hooks. They'll be called on quantized input

File ~/.local/bin/.virtualenvs/plm/lib/python3.8/site-packages/torch/nn/quantized/modules/normalization.py:34, in LayerNorm.from_float(cls, mod)
     32 @classmethod
     33 def from_float(cls, mod):
---> 34     scale, zero_point = mod.activation_post_process.calculate_qparams()
     35     new_mod = cls(
     36         mod.normalized_shape, mod.weight, mod.bias, float(scale),
     37         int(zero_point), mod.eps, mod.elementwise_affine)
     38     return new_mod

File ~/.local/bin/.virtualenvs/plm/lib/python3.8/site-packages/torch/ao/quantization/observer.py:1256, in PlaceholderObserver.calculate_qparams(self)
   1254 @torch.jit.export
   1255 def calculate_qparams(self):
-> 1256     raise Exception(
   1257         "calculate_qparams should not be called for PlaceholderObserver"
   1258     )

Exception: calculate_qparams should not be called for PlaceholderObserver

@andrewor any update?

I have also tried something like:

backend = "fbgemm"

qconfig_dict = {
    # Weight-only qconfig for embeddings, default static qconfig for linears.
    torch.nn.Embedding: torch.quantization.float_qparams_weight_only_qconfig,
    torch.nn.Linear: torch.quantization.get_default_qconfig(backend),
}

torch.quantization.propagate_qconfig_(roberta_model, qconfig_dict)
torch.quantization.prepare(roberta_model, inplace=True)

And then, after calibration, when I call convert, it results in:

File ~/.local/bin/.virtualenvs/plm/lib/python3.8/site-packages/torch/nn/quantized/modules/embedding_ops.py:112, in Embedding.forward(self, indices)
    111 def forward(self, indices: Tensor) -> Tensor:
--> 112     return torch.ops.quantized.embedding_byte(self._packed_params._packed_weight, indices)

RuntimeError: Expect weight, indices, and offsets to be contiguous.

The roberta_model is defined like so (following the recommendation of this guide):

from typing import Optional

import torch
import torch.nn as nn


class QuantizedRobertaForSequenceClassification(nn.Module):
    def __init__(self, roberta_model):
        super().__init__()
        self.roberta_model = roberta_model.eval()
        # Stubs mark the float -> quantized and quantized -> float boundaries.
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        # Force contiguous tensors to try to work around the
        # "Expect weight, indices, and offsets to be contiguous" error.
        if not input_ids.is_contiguous():
            input_ids = input_ids.contiguous()
        if not attention_mask.is_contiguous():
            attention_mask = attention_mask.contiguous()
        if labels is not None and not labels.is_contiguous():
            labels = labels.contiguous()

        input_dict = {"input_ids": self.quant(input_ids)}
        input_dict.update({"attention_mask": self.quant(attention_mask)})
        if labels is not None:
            input_dict.update({"labels": self.quant(labels)})

        outputs = self.roberta_model(**input_dict)
        x = self.dequant(outputs.logits)
        return x
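
For completeness, the pipeline around this wrapper looks roughly like this (a sketch; the actual checkpoint, qconfig_dict, and test_loader are the ones shown above / in the notebook):

from transformers import AutoModelForSequenceClassification

# Sketch: wrap the model, attach qconfigs, calibrate, then convert.
base_model = AutoModelForSequenceClassification.from_pretrained("roberta-base")
roberta_model = QuantizedRobertaForSequenceClassification(base_model)

torch.quantization.propagate_qconfig_(roberta_model, qconfig_dict)
torch.quantization.prepare(roberta_model, inplace=True)

with torch.no_grad():  # calibration pass over test_loader
    for batch in test_loader:
        roberta_model(batch["input_ids"], batch["attention_mask"])

roberta_quant_model = torch.quantization.convert(roberta_model)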