Wav2vec2 quantization dimension error

I have been trying to quantize the WAV2VEC2_ASR_BASE_960H model using static FX mode, but I'm getting an example input size error. I'm not sure what size to give (the sample input size is (1, 54400) and the sampling rate is 16 kHz). Can anybody also help me with the steps to statically quantize the model?

import copy

import torch
import torchaudio
from torch.quantization import quantize_fx

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().eval()  # static quantization expects eval mode
m = copy.deepcopy(model)
qconfig_dict = {"": torch.quantization.get_default_qconfig("fbgemm")}
# example_inputs is a tuple of positional inputs: (waveform,), waveform shape (batch, samples)
model_prepared = quantize_fx.prepare_fx(m, qconfig_dict, example_inputs=(torch.rand(1, 54400),))
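For the "steps" part of the question, the usual FX graph mode static quantization flow (prepare, calibrate, convert) can be sketched on a small traceable model. This is a minimal sketch under assumptions: TinyConvEncoder is a hypothetical toy model, not wav2vec2, and the fbgemm/qnnpack fallback is a guess about your hardware. With the real model, the prepare step still only works if the model is symbolically traceable.

```python
import copy

import torch
from torch import nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

# Assumption: use fbgemm on x86, fall back to qnnpack (e.g. on ARM) if unavailable.
engine = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = engine


class TinyConvEncoder(nn.Module):
    """Hypothetical toy model standing in for a traceable audio front end."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 8, kernel_size=10, stride=5)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(8, 8, kernel_size=3, stride=2)

    def forward(self, x):  # x: (batch, channels=1, samples)
        return self.conv2(self.relu(self.conv1(x)))


float_model = TinyConvEncoder().eval()  # static quantization expects eval mode

# 1. Prepare: trace the model and insert observers. example_inputs is a tuple.
example_inputs = (torch.randn(1, 1, 16000),)
prepared = prepare_fx(copy.deepcopy(float_model), get_default_qconfig_mapping(engine), example_inputs)

# 2. Calibrate: run representative data through the observed model so the
#    observers record activation ranges. Random tensors are only a placeholder
#    for real 16 kHz validation audio.
with torch.no_grad():
    for _ in range(10):
        prepared(torch.randn(1, 1, 16000))

# 3. Convert: replace observed modules with quantized ones.
quantized = convert_fx(prepared)

with torch.no_grad():
    out = quantized(torch.randn(2, 1, 16000))
print(out.shape, out.dtype)  # torch.Size([2, 8, 1599]) torch.float32
```

The output is float32 because convert_fx dequantizes at the model output by default; internally the convolutions run quantized.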

This doesn't look like an example input size error; it looks like a symbolic tracing error. If you skip quantization and just do normal FX tracing, does your model work?


I tried with:

import copy

import torch
from torch.quantization import quantize_fx
from transformers import Wav2Vec2ForCTC

model_f = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
m = copy.deepcopy(model_f)

qconfig_dict = {"": torch.quantization.get_default_qconfig("fbgemm")}
# Prepare
model_prepared = quantize_fx.prepare_fx(m, qconfig_dict,example_inputs = torch.randn(768,512))
# Calibrate - use representative (validation) data; random tensors are only a placeholder
with torch.inference_mode():
    for _ in range(10):
        x = torch.randn(768,512)
        model_prepared(x)  # the forward pass is what lets the observers record ranges
# Quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

Getting this error:

TypeError                                 Traceback (most recent call last)
<ipython-input-7-a2e234fa7675> in <cell line: 17>()
     15 qconfig_dict = {"": torch.quantization.get_default_qconfig("fbgemm")}
     16 # Prepare
---> 17 model_prepared = quantize_fx.prepare_fx(m, qconfig_dict,example_inputs = torch.randn(768,512))
     18 # Calibrate - Use representative (validation) data.
     19 with torch.inference_mode():

10 frames
/usr/local/lib/python3.10/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py in _get_feature_vector_attention_mask(self, feature_vector_length, attention_mask, add_adapter)
   1151         batch_size = attention_mask.shape[0]
-> 1153         attention_mask = torch.zeros(
   1154             (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
   1155         )

TypeError: zeros() received an invalid combination of arguments - got (tuple, device=Attribute, dtype=Attribute), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

I get the same error with normal FX tracing too, without any quantization:

import torch
from torch.fx import symbolic_trace

# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(model)
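The same failure can be reproduced without transformers. This is a minimal sketch with two hypothetical toy modules (FactoryZeros and MethodZeros are invented for illustration): the first mirrors the failing torch.zeros line, the second builds the same tensor with the Tensor method new_zeros, which FX records as a call_method node.

```python
import torch
from torch import fx, nn


class FactoryZeros(nn.Module):
    """Hypothetical toy module mirroring the failing line in modeling_wav2vec2.py."""

    def forward(self, mask):
        # While tracing, mask is a Proxy, so mask.shape[0], mask.dtype and
        # mask.device are Proxies too. torch.zeros is a factory function that FX
        # does not record, so it receives Proxies instead of ints/dtype/device.
        return torch.zeros((mask.shape[0], 4), dtype=mask.dtype, device=mask.device)


class MethodZeros(nn.Module):
    """Same output, written with a Tensor method that FX can record."""

    def forward(self, mask):
        # new_zeros becomes a call_method node; at run time the result takes
        # its dtype and device from mask, like the original line intended.
        return mask.new_zeros((mask.shape[0], 4))


try:
    fx.symbolic_trace(FactoryZeros())
    factory_traced = True
except TypeError as err:
    factory_traced = False
    print("FactoryZeros failed to trace:", err)

traced = fx.symbolic_trace(MethodZeros())
out = traced(torch.ones(3, 7, dtype=torch.int32))
print(out.shape, out.dtype)  # torch.Size([3, 4]) torch.int32
```

Rewriting that line of the library as attention_mask.new_zeros((batch_size, feature_vector_length)) might therefore be an alternative to hardcoding dtype and device. This is untested on the real model, and other lines in the same function (for example other factory calls such as torch.arange) may hit the same problem.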

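In that case it's not a quantization issue, it's a tracing issue. You can either use eager mode quantization (which removes the need for traceability) or make this model traceable. I'm not an expert, but I think the problem is that the tracer doesn't run the model on real data: during symbolic tracing, attention_mask is a Proxy, so attention_mask.dtype and attention_mask.device are Proxy attributes (the "Attribute" objects in the error) rather than a real torch.dtype and torch.device, and torch.zeros rejects them. If you hardcode dtype and device in that line of the model, it may get past that point: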

attention_mask = torch.zeros(
    (batch_size, feature_vector_length),
    dtype=torch.float32,  # or whatever dtype your attention mask actually uses
    device="cpu",  # hardcoded; "cpu" is probably right for fbgemm quantization
)
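Eager mode quantization, the first option above, does not trace the model; instead you mark where quantized execution starts and stops with QuantStub and DeQuantStub yourself. A minimal sketch on a hypothetical toy model (TinyConvEncoder is invented for illustration, and the fbgemm/qnnpack fallback is an assumption about your hardware). Where to put the stubs inside the real wav2vec2 model, and which of its layers quantize well, is not shown here and would need to be worked out.

```python
import torch
from torch import nn
from torch.ao.quantization import (
    DeQuantStub,
    QuantStub,
    convert,
    fuse_modules,
    get_default_qconfig,
    prepare,
)

# Assumption: fbgemm on x86, qnnpack elsewhere (e.g. ARM).
engine = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = engine


class TinyConvEncoder(nn.Module):
    """Hypothetical toy model; a real model needs stubs placed by hand."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()  # float -> quantized where quantized execution starts
        self.conv1 = nn.Conv1d(1, 8, kernel_size=10, stride=5)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(8, 8, kernel_size=3, stride=2)
        self.dequant = DeQuantStub()  # quantized -> float where it ends

    def forward(self, x):  # x: (batch, 1, samples)
        x = self.quant(x)
        x = self.conv2(self.relu(self.conv1(x)))
        return self.dequant(x)


model = TinyConvEncoder().eval()
fuse_modules(model, [["conv1", "relu"]], inplace=True)  # conv1 + relu -> one fused module
model.qconfig = get_default_qconfig(engine)
prepare(model, inplace=True)  # attach observers to modules; no tracing involved

with torch.no_grad():  # calibrate; random data is only a placeholder for real audio
    for _ in range(10):
        model(torch.randn(1, 1, 16000))

convert(model, inplace=True)  # swap observed modules for quantized ones

with torch.no_grad():
    out = model(torch.randn(2, 1, 16000))
print(out.shape, out.dtype)  # torch.Size([2, 8, 1599]) torch.float32
```

Because nothing is traced, the attention-mask error cannot occur here; the trade-off is that you have to edit or wrap the model's modules to insert the stubs.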