Static quantization/QAT change the shape of embedding layer outputs

Hello,
I am trying to apply different kinds of quantization (static, dynamic and quantization-aware training) to a BERT model taken from the transformers library. My model works as expected with no quantization or with dynamic quantization, but both static quantization and quantization-aware training fail with a strange error that makes it seem as though QuantStub and DeQuantStub are changing tensor shapes.

Here is a simplified version; the comments indicate where the errors occur and the reasoning behind the various tweaks. The two error stacks described in the comments (Error A and Error B) are included at the end of this post.

import torch
import torch.nn as nn
import torch.quantization as quant
from transformers import AutoModelForTokenClassification, Trainer

def __init__(self, ...):
    ...
    self.model = AutoModelForTokenClassification.from_pretrained('bert-base-multilingual-cased', config=...)
    ...

def train(self, trainset, devset, num_epochs, **kwargs):
    self.trainer = Trainer(model=self.model, args=..., train_dataset=trainset, eval_dataset=devset)
    
    # Without any kind of quantization, the model works as expected
    
    # Case 1: dynamic quantization for Linear layers, works as expected
    if self.quantization == "dynamic":
        self.trainer.train()
        self.model = quant.quantize_dynamic(
            self.model, {nn.Linear}, dtype=getattr(torch, self.quant_type)
        )
        self.trainer.model = self.model

    # Case 2: quantization-aware training. The errors are exactly the same as for static quantization.
    elif self.quantization == "qat":
        self.model.quant = quant.QuantStub()
        self.model.dequant = quant.DeQuantStub()

        # This loop is needed in the first place because of https://discuss.pytorch.org/t/89154,
        # otherwise I get "AssertionError: The only supported dtype for nnq.Embedding is torch.quint8"
        for module in self.model.modules():
            if isinstance(module, nn.Embedding):
                # Both alternatives here cause different errors:
                # module.qconfig = None  # causes Error A (presumably because of operating on both quantized and non-quantized tensors)
                module.qconfig = quant.float_qparams_dynamic_qconfig  # causes Error B (why?)

        self.model.qconfig = quant.get_default_qat_qconfig('fbgemm')
        quant.prepare_qat(self.model, inplace=True)

        Trainer.compute_loss = quant_compute_loss
        Trainer.prediction_step = quant_prediction_step
        self.model.train()
        self.trainer.train()
        self.model.eval()

        self.model = quant.convert(self.model, inplace=True)
        self.trainer.model = self.model

    # Case 3: static quantization, not included here. The code is very similar and I get the exact same errors in the same place.
    elif self.quantization == "static":
        pass

    self.trainer.save_model('model')
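
For reference, the standard eager-mode QAT recipe on a self-contained toy module looks roughly like this (a minimal sketch adapted from the PyTorch quantization docs, not my actual model); in my case I only attach the stubs as attributes, since I'd rather not edit the library's forward methods:

import torch
import torch.nn as nn
import torch.quantization as quant

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = quant.QuantStub()      # marks where float inputs become quantized tensors
        self.fc = nn.Linear(4, 4)
        self.relu = nn.ReLU()
        self.dequant = quant.DeQuantStub()  # marks where outputs become float again

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.fc(x))
        return self.dequant(x)

model = ToyModel()
model.qconfig = quant.get_default_qat_qconfig('fbgemm')
quant.prepare_qat(model, inplace=True)
model(torch.randn(2, 4))  # stand-in for fine-tuning with fake quantization enabled
model.eval()
quant.convert(model, inplace=True)
print(model(torch.randn(2, 4)).shape)  # torch.Size([2, 4])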

Error A (quantization disabled for Embedding layers):

Traceback (most recent calls WITHOUT Sacred internals):
  File "/home/username/project/filename.py", line 402, in predict
    batch_pred_ids, label_ids, _ = self.trainer.predict(dataset)
  File "/home/username/.local/lib/python3.6/site-packages/transformers/trainer.py", line 1355, in predict
    return self.prediction_loop(test_dataloader, description="Prediction")
  File "/home/username/.local/lib/python3.6/site-packages/transformers/trainer.py", line 1417, in prediction_loop
    loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
  File "/home/username/project/filename.py", line 79, in quant_prediction_step
    outputs = model(**inputs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/username/.local/lib/python3.6/site-packages/transformers/modeling_bert.py", line 1539, in forward
    return_dict=return_dict,
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/username/.local/lib/python3.6/site-packages/transformers/modeling_bert.py", line 838, in forward
    input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/username/.local/lib/python3.6/site-packages/transformers/modeling_bert.py", line 202, in forward
    embeddings = self.LayerNorm(embeddings)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/quantized/modules/normalization.py", line 25, in forward
    eps=self.eps, output_scale=self.scale, output_zero_point=self.zero_point)
RuntimeError: Could not run 'quantized::layer_norm' with arguments from the 'CPU' backend. 'quantized::layer_norm' is only available for these backends: [QuantizedCPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, Tracer, Autocast, Batched, VmapMode].

Error B (modified qconfig for embedding layers):

Traceback (most recent calls WITHOUT Sacred internals):
  File "/home/username/project/filename.py", line 401, in predict
    batch_pred_ids, label_ids, _ = self.trainer.predict(dataset)
  File "/home/username/.local/lib/python3.6/site-packages/transformers/trainer.py", line 1355, in predict
    return self.prediction_loop(test_dataloader, description="Prediction")
  File "/home/username/.local/lib/python3.6/site-packages/transformers/trainer.py", line 1417, in prediction_loop
    loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
  File "/home/username/project/filename.py", line 79, in quant_prediction_step
    outputs = model(**inputs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/username/.local/lib/python3.6/site-packages/transformers/modeling_bert.py", line 1539, in forward
    return_dict=return_dict,
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/username/.local/lib/python3.6/site-packages/transformers/modeling_bert.py", line 838, in forward
    input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/username/.local/lib/python3.6/site-packages/transformers/modeling_bert.py", line 201, in forward
    embeddings = inputs_embeds + position_embeddings + token_type_embeddings
RuntimeError: The size of tensor a (1024) must match the size of tensor b (128) at non-singleton dimension 0

Error A seems understandable to me, but I’m not sure what is going on with the second one. If I check manually, without quantization, inputs_embeds, position_embeddings and token_type_embeddings have shape (8, 128, 768), or sometimes (1, 128, 768), at the point where the error occurs. But with QAT (or static quantization), they have shape (1024, 768) or (128, 768) instead, as if the first two dimensions had been flattened into one. Is there a way to change them back to their correct shape, maybe by changing the qconfig? I’d rather not edit the library, especially considering that my model type is determined at runtime.
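
A forward hook along these lines can be used to confirm the shapes (a hypothetical helper, not part of the code above; the attribute path assumes the standard BertForTokenClassification layout):

# Print the output shapes of the embedding sub-modules during a forward pass.
# Re-run this after quant.convert(), since convert swaps the modules and any
# previously registered hooks are lost with them.
def log_embedding_shapes(model):
    def make_hook(name):
        def hook(module, inputs, output):
            print(f"{name} ({module.__class__.__name__}): {tuple(output.shape)}")
        return hook
    embeddings = model.bert.embeddings
    for name in ("word_embeddings", "position_embeddings", "token_type_embeddings"):
        getattr(embeddings, name).register_forward_hook(make_hook(name))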

This is surprising. Can you provide a smaller repro so that we can investigate this further, something like this snippet alone:

    if inputs_embeds is None:
        inputs_embeds = self.word_embeddings(input_ids)
    token_type_embeddings = self.token_type_embeddings(token_type_ids)

    embeddings = inputs_embeds + token_type_embeddings

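Something along these lines might serve as a starting point for such a standalone repro (an untested sketch: the toy module and sizes are made up, and the qconfig handling mirrors the code in the original post):

import torch
import torch.nn as nn
import torch.quantization as quant

class TinyEmbeddings(nn.Module):
    def __init__(self, vocab_size=100, dim=8):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, dim)
        self.token_type_embeddings = nn.Embedding(2, dim)
        self.LayerNorm = nn.LayerNorm(dim)
        self.quant = quant.QuantStub()
        self.dequant = quant.DeQuantStub()

    def forward(self, input_ids, token_type_ids):
        # Same pattern as BertEmbeddings: sum the embeddings, then LayerNorm
        embeddings = self.word_embeddings(input_ids) + self.token_type_embeddings(token_type_ids)
        embeddings = self.quant(embeddings)
        embeddings = self.LayerNorm(embeddings)
        return self.dequant(embeddings)

model = TinyEmbeddings()
model.qconfig = quant.get_default_qat_qconfig("fbgemm")
for module in model.modules():
    if isinstance(module, nn.Embedding):
        module.qconfig = quant.float_qparams_dynamic_qconfig

quant.prepare_qat(model, inplace=True)

input_ids = torch.randint(0, 100, (4, 16))
token_type_ids = torch.zeros(4, 16, dtype=torch.long)
model(input_ids, token_type_ids)  # stand-in for a training step, so the observers see data

model.eval()
quant.convert(model, inplace=True)
out = model(input_ids, token_type_ids)
print(out.shape)  # expected (4, 16, 8); checks whether the batch/sequence dims get flattened
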
@pie3636 did you resolve this issue? Which iterator are you using to iterate over the modules when reassigning the qconfig?