Quantise the wav2vec2 model

I am trying to quantise the wav2vec2 model using static fx mode. by=ut encountering the error: ,

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from torch.quantization import quantize_fx
model_f = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

from torch.quantization import quantize_fx
m = copy.deepcopy(model_f)
m.eval()
qconfig_dict = {"": torch.quantization.get_default_qconfig("fbgemm")}
# Prepare
model_prepared = quantize_fx.prepare_fx(m, qconfig_dict,example_inputs = torch.randn(1,1,54400))
# Calibrate - Use representative (validation) data.
with torch.inference_mode():
  for _ in range(10):
    x = torch.rand(1,1,54400)
    model_prepared(x)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)
TypeError                                 Traceback (most recent call last)
<ipython-input-26-0604642dc2ff> in <cell line: 10>()
      8 qconfig_dict = {"": torch.quantization.get_default_qconfig("fbgemm")}
      9 # Prepare
---> 10 model_prepared = quantize_fx.prepare_fx(m, qconfig_dict,example_inputs = torch.randn(1,1,54400))
     11 # Calibrate - Use representative (validation) data.
     12 with torch.inference_mode():

10 frames
/usr/local/lib/python3.10/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py in _get_feature_vector_attention_mask(self, feature_vector_length, attention_mask, add_adapter)
   1151         batch_size = attention_mask.shape[0]
   1152 
-> 1153         attention_mask = torch.zeros(
   1154             (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
   1155         )

TypeError: zeros() received an invalid combination of arguments - got (tuple, device=Attribute, dtype=Attribute), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)