RuntimeError: The size of tensor a (2) must match the size of tensor b (0) at non-singleton dimension 1

I am attempting to get verbatim transcripts from mp3 files using CrisperWhisper through Transformers. I am receiving this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[9], line 5
      2 output_txt = r"C:\Users\pryce\PycharmProjects\LostInTranscription\data\WER0\001_test.txt"
      4 print("Transcribing:", audio_file)
----> 5 transcript_text = transcribe_audio(audio_file, asr_pipe, chunk_length_s = 30, overlap_s = 0.5)
      7 # display first 50 lines
      8 print("---- Transcript preview ----")

Cell In[8], line 19, in transcribe_audio(audio_path, asr_pipeline, chunk_length_s, overlap_s)
     17 for chunk, start in audio_chunks:
     18     sample = {"array": chunk.astype("float32"), "sampling_rate": sr}
---> 19     hf_out = asr_pipeline(sample)
     20     print(hf_out)
     22     # shift timestamps by chunk_start

File ~\PycharmProjects\LostInTranscription\LIT\Lib\site-packages\transformers\pipelines\automatic_speech_recognition.py:275, in AutomaticSpeechRecognitionPipeline.__call__(self, inputs, **kwargs)
    218 def __call__(self, inputs: Union[np.ndarray, bytes, str, dict], **kwargs: Any) -> list[dict[str, Any]]:
    219     """
    220     Transcribe the audio sequence(s) given as inputs to text. See the [`AutomaticSpeechRecognitionPipeline`]
    221     documentation for more information.
   (...)    273                 `"".join(chunk["text"] for chunk in output["chunks"])`.
    274     """
--> 275     return super().__call__(inputs, **kwargs)

File ~\PycharmProjects\LostInTranscription\LIT\Lib\site-packages\transformers\pipelines\base.py:1459, in Pipeline.__call__(self, inputs, num_workers, batch_size, *args, **kwargs)
   1457     return self.iterate(inputs, preprocess_params, forward_params, postprocess_params)
   1458 elif self.framework == "pt" and isinstance(self, ChunkPipeline):
-> 1459     return next(
   1460         iter(
   1461             self.get_iterator(
   1462                 [inputs], num_workers, batch_size, preprocess_params, forward_params, postprocess_params
   1463             )
   1464         )
   1465     )
   1466 else:
   1467     return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)

File ~\PycharmProjects\LostInTranscription\LIT\Lib\site-packages\transformers\pipelines\pt_utils.py:126, in PipelineIterator.__next__(self)
    123     return self.loader_batch_item()
    125 # We're out of items within a batch
--> 126 item = next(self.iterator)
    127 processed = self.infer(item, **self.params)
    128 # We now have a batch of "inferred things".

File ~\PycharmProjects\LostInTranscription\LIT\Lib\site-packages\transformers\pipelines\pt_utils.py:271, in PipelinePackIterator.__next__(self)
    268             return accumulator
    270 while not is_last:
--> 271     processed = self.infer(next(self.iterator), **self.params)
    272     if self.loader_batch_size is not None:
    273         if isinstance(processed, torch.Tensor):

File ~\PycharmProjects\LostInTranscription\LIT\Lib\site-packages\transformers\pipelines\base.py:1374, in Pipeline.forward(self, model_inputs, **forward_params)
   1372     with inference_context():
   1373         model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
-> 1374         model_outputs = self._forward(model_inputs, **forward_params)
   1375         model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
   1376 else:

File ~\PycharmProjects\LostInTranscription\LIT\Lib\site-packages\transformers\pipelines\automatic_speech_recognition.py:535, in AutomaticSpeechRecognitionPipeline._forward(self, model_inputs, return_timestamps, **generate_kwargs)
    529 main_input_name = self.model.main_input_name if hasattr(self.model, "main_input_name") else "inputs"
    530 generate_kwargs = {
    531     main_input_name: inputs,
    532     "attention_mask": attention_mask,
    533     **generate_kwargs,
    534 }
--> 535 tokens = self.model.generate(**generate_kwargs)
    537 # whisper longform generation stores timestamps in "segments"
    538 if return_timestamps == "word" and self.type == "seq2seq_whisper":

File ~\PycharmProjects\LostInTranscription\LIT\Lib\site-packages\transformers\models\whisper\generation_whisper.py:866, in WhisperGenerationMixin.generate(self, input_features, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_timestamps, task, language, is_multilingual, prompt_ids, prompt_condition_type, condition_on_prev_tokens, temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold, num_segment_frames, attention_mask, time_precision, time_precision_features, return_token_timestamps, return_segments, return_dict_in_generate, force_unique_generate_call, monitor_progress, **kwargs)
    857             proc.set_begin_index(decoder_input_ids.shape[-1])
    859 # 6.6 Run generate with fallback
    860 (
    861     seek_sequences,
    862     seek_outputs,
    863     should_skip,
    864     do_condition_on_prev_tokens,
    865     model_output_type,
--> 866 ) = self.generate_with_fallback(
    867     segment_input=segment_input,
    868     decoder_input_ids=decoder_input_ids,
    869     cur_bsz=cur_bsz,
    870     seek=seek,
    871     batch_idx_map=batch_idx_map,
    872     temperatures=temperatures,
    873     generation_config=generation_config,
    874     logits_processor=logits_processor,
    875     stopping_criteria=stopping_criteria,
    876     prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    877     synced_gpus=synced_gpus,
    878     return_token_timestamps=return_token_timestamps,
    879     do_condition_on_prev_tokens=do_condition_on_prev_tokens,
    880     is_shortform=is_shortform,
    881     batch_size=batch_size,
    882     attention_mask=attention_mask,
    883     kwargs=kwargs,
    884 )
    886 # 6.7 In every generated sequence, split by timestamp tokens and extract segments
    887 for i, seek_sequence in enumerate(seek_sequences):

File ~\PycharmProjects\LostInTranscription\LIT\Lib\site-packages\transformers\models\whisper\generation_whisper.py:1053, in WhisperGenerationMixin.generate_with_fallback(self, segment_input, decoder_input_ids, cur_bsz, seek, batch_idx_map, temperatures, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_token_timestamps, do_condition_on_prev_tokens, is_shortform, batch_size, attention_mask, kwargs)
   1050 model_output_type = type(seek_outputs)
   1052 # post-process sequence tokens and outputs to be in list form
-> 1053 seek_sequences, seek_outputs = self._postprocess_outputs(
   1054     seek_outputs=seek_outputs,
   1055     decoder_input_ids=decoder_input_ids,
   1056     return_token_timestamps=return_token_timestamps,
   1057     generation_config=generation_config,
   1058     is_shortform=is_shortform,
   1059     seek=seek,
   1060     batch_idx_map=batch_idx_map,
   1061 )
   1063 if cur_bsz < batch_size:
   1064     seek_sequences = seek_sequences[:cur_bsz]

File ~\PycharmProjects\LostInTranscription\LIT\Lib\site-packages\transformers\models\whisper\generation_whisper.py:1163, in WhisperGenerationMixin._postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, is_shortform, seek, batch_idx_map)
   1160         num_frames = num_frames - seek
   1161         num_frames = num_frames[batch_idx_map]
-> 1163     seek_outputs["token_timestamps"] = self._extract_token_timestamps(
   1164         seek_outputs,
   1165         generation_config.alignment_heads,
   1166         num_frames=num_frames,
   1167         num_input_ids=decoder_input_ids.shape[-1],
   1168     )
   1170 def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None):
   1171     if beam_indices is not None and key == "scores":

File ~\PycharmProjects\LostInTranscription\LIT\Lib\site-packages\transformers\models\whisper\generation_whisper.py:285, in WhisperGenerationMixin._extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision, num_frames, num_input_ids)
    281 if num_input_ids is not None and num_input_ids > 1:
    282     # `-1`: `beam_indices` can be used as-is to gather the weights when `num_input_ids` is 1
    283     weight_length += num_input_ids - 1
    284     beam_indices_first_step_unrolled = (
--> 285         torch.ones(beam_indices.shape[0], num_input_ids - 1, device=beam_indices.device, dtype=torch.long)
    286         * (beam_indices[:, 0:1])
    287     )
    288     unrolled_beam_indices = torch.cat([beam_indices_first_step_unrolled, beam_indices], dim=-1)
    289 else:

RuntimeError: The size of tensor a (2) must match the size of tensor b (0) at non-singleton dimension 1

My model configuration is the same as on the CrisperWhisper model card:

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if "cuda" in device else torch.float32
model_id = "nyrahealth/CrisperWhisper"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)

model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,
    batch_size=16,
    return_timestamps="word",
    device=0 if "cuda" in device else -1,
)

My audio file is loaded with librosa, resampled to 16,000 Hz, and split into 30-second chunks; chunks shorter than that are padded with zeroes. My call to the pipeline is as follows:

sample = {"array": chunk.astype("float32"), "sampling_rate": sr}
hf_out = asr_pipeline(sample)
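
For context, transcribe_audio is essentially the following (a simplified sketch of my actual function; the chunking arithmetic is paraphrased and the timestamp-shifting code is omitted):

import librosa
import numpy as np

def transcribe_audio(audio_path, asr_pipeline, chunk_length_s=30, overlap_s=0.5):
    # Load as mono and resample to 16 kHz
    audio, sr = librosa.load(audio_path, sr=16000, mono=True)

    # Split into fixed-length windows with a small overlap between consecutive chunks;
    # the final chunk is zero-padded up to exactly chunk_length_s * sr samples
    chunk_len = int(chunk_length_s * sr)
    step = int((chunk_length_s - overlap_s) * sr)
    audio_chunks = []
    for start in range(0, len(audio), step):
        chunk = audio[start:start + chunk_len]
        if len(chunk) < chunk_len:
            chunk = np.pad(chunk, (0, chunk_len - len(chunk)))
        audio_chunks.append((chunk, start / sr))

    texts = []
    for chunk, chunk_start in audio_chunks:
        sample = {"array": chunk.astype("float32"), "sampling_rate": sr}
        hf_out = asr_pipeline(sample)   # <- the RuntimeError is raised here
        print(hf_out)

        # shift timestamps by chunk_start (details omitted)
        texts.append(hf_out["text"])

    return " ".join(texts)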

So far, I have:

  • Verified that my audio loads with the correct shape
  • Verified that every chunk has a length of exactly chunk_length_s * sr (30 x 16,000 = 480,000 samples in this case); the checks are sketched after this list
  • Replaced "array" in the definition of sample with "raw"
  • Defined sample without converting chunk to float32 (also tried float16)

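The shape and length checks from the first two bullets amount to roughly the following (a rough sketch; it runs inside the chunk loop of transcribe_audio):

# Sanity checks on each chunk before it is passed to the pipeline
for chunk, start in audio_chunks:
    assert chunk.ndim == 1                      # mono, 1-D waveform
    assert len(chunk) == chunk_length_s * sr    # 30 * 16,000 = 480,000 samples
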
None of the above has had any effect on the error. My main issue is that I don't know which tensors a and b refer to here, or what would have sizes of 2 and 0. Any insights would be appreciated!