Hello,
I’m using CTC loss for ASR, specifically phoneme recognition. My model is a pretrained Whisper tiny encoder (not frozen during training) with a linear classification head as the decoder, followed by a LogSoftmax activation layer. My dataset is L2-Arctic.
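For context, here is a minimal sketch of the architecture. I’m assuming the HuggingFace transformers Whisper implementation; the class count and variable names are illustrative:

```python
import torch
import torch.nn as nn
from transformers import WhisperModel

class PhonemeCTCModel(nn.Module):
    def __init__(self, num_phonemes: int):
        super().__init__()
        # Pretrained Whisper tiny encoder, left unfrozen so it gets fine-tuned
        self.encoder = WhisperModel.from_pretrained("openai/whisper-tiny").encoder
        # Linear head maps the 384-dim hidden states to num_phonemes + 1 classes
        # (+1 for the CTC blank, which is class 0 here)
        self.head = nn.Linear(384, num_phonemes + 1)
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
        # input_features: (batch, 80 mel bins, 3000 frames) for 30 s of audio
        hidden = self.encoder(input_features).last_hidden_state  # (batch, 1500, 384)
        return self.log_softmax(self.head(hidden))               # (batch, 1500, classes)
```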
I have looked everywhere for exactly what to pass as inputs to PyTorch’s CTCLoss, but it is not clearly explained anywhere. I can’t find the source code of the torch.ctc_loss() function, only its wrappers.
I am currently using sequences of classes deduced from the max predicted probability for each window, with blank padding (class 0). For the input lengths, I take these same sequences and remove the repetitions as well as the blanks (see the sketch after the questions below).
- Should I input something else?
- Should the input lengths take repetitions into account?
- Should the input lengths take blanks into account?
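For reference, this is the calling convention I understand nn.CTCLoss to expect from the PyTorch docs; the tensor contents and dimensions here are made up for illustration:

```python
import torch
import torch.nn as nn

T, N, C, S = 1500, 4, 40, 25  # frames, batch size, classes (incl. blank), max target length

ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

# The docs ask for per-frame log-probabilities over all classes, time-major
log_probs = torch.randn(T, N, C).log_softmax(dim=-1)  # (T, N, C)
targets = torch.randint(1, C, (N, S))                 # (N, S), no blanks inside
input_lengths = torch.full((N,), T)                   # valid frames per sample, each <= T
target_lengths = torch.randint(10, S + 1, (N,))       # true phoneme count per sample

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
```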
As for the targets, I am inputting the classes without blanks as well:
target_lengths = torch.count_nonzero(batch["phonemes"], dim=1)
I use the zero_infinity=True flag because, after some training, my model predicts so much silence that the input lengths become shorter than the target lengths.
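Concretely, with my zero-padded phoneme batches the length computation looks like this (the batch contents are made up for illustration):

```python
import torch

# Hypothetical padded batch: 0 is both the blank class and the padding value
batch = {"phonemes": torch.tensor([
    [5, 2, 9, 3, 0, 0],   # 4 real phonemes, 2 padding zeros
    [7, 7, 1, 4, 8, 2],   # 6 real phonemes, no padding
])}

# Counting nonzero entries per row gives the unpadded target length
target_lengths = torch.count_nonzero(batch["phonemes"], dim=1)
print(target_lengths)  # tensor([4, 6])
```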
I am asking this because my model learns to predict too much silence and the wrong classes. When I used the full prediction lengths, without removing blanks or repetitions, the model learned to predict the same class for every window (e.g. predicting 2222222222222222222222222222).