[CTC Loss] Does CTC loss not support float16?

Hello. I'm trying to train an ASR model with CTC loss.
But when I enable mixed precision training, the CTC loss does not decrease and the model predicts only the blank token for several epochs, even though I start from a pretrained wav2vec2 model.

I'm not sure which part is disrupting training, but I suspect that wrapping the backward pass and optimizer step with the GradScaler is the critical one.
How should I change the code to train with CTC loss under mixed precision? Or is there some other problem in my code?

I compute the CTC loss as shown below, following Hugging Face's implementation:

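```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# logits: (batch, time, classes); ctc_loss expects log_probs as (time, batch, classes)
log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

# cuDNN's CTC implementation is disabled, so PyTorch's native kernel is used
with torch.backends.cudnn.flags(enabled=False):
    loss = nn.functional.ctc_loss(
        log_probs,
        flattened_targets,
        feature_lengths,
        target_lengths,
        blank=self.config.blank_idx,
        reduction=self.config.ctc_loss_reduction,
        zero_infinity=self.config.ctc_zero_infinity,
    )
```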

I also wrapped the forward pass in autocast and scaled the backward pass and optimizer step with GradScaler:

```python
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler(enabled=config.fp16_run)

with autocast(enabled=config.fp16_run):
    predictions = model(inputs)

scaler.scale(predictions['loss']).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```
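For comparison, here is a minimal, self-contained sketch of one common pattern: run the model's forward pass under autocast, then compute log_softmax and the CTC loss explicitly in float32 outside the autocast region, with GradScaler handling the backward pass and optimizer step. The toy model, tensor sizes, and learning rate are hypothetical stand-ins (not wav2vec2), and on a machine without CUDA the sketch simply runs in float32 with AMP disabled. Since the original snippet already requests float32 via log_softmax(dtype=torch.float32), this mainly makes the float32 boundary explicit; it is not guaranteed to fix the blank-only predictions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # in this sketch, mixed precision is only enabled on GPU

torch.manual_seed(0)

# Hypothetical sizes: batch N, T frames, F_IN input features, C classes (0 = blank), S labels.
N, T, F_IN, C, S = 4, 60, 16, 20, 8

# Hypothetical toy acoustic model standing in for wav2vec2: features -> per-frame logits.
model = nn.Sequential(nn.Linear(F_IN, 64), nn.ReLU(), nn.Linear(64, C)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # pass-through when disabled

inputs = torch.randn(N, T, F_IN, device=device)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), S, dtype=torch.long)
# Concatenated 1-D targets; labels start at 1 because 0 is the blank index.
flattened_targets = torch.randint(1, C, (N * S,), device=device)


def train_step() -> float:
    optimizer.zero_grad(set_to_none=True)
    # Forward pass under autocast: matmuls may run in float16 on GPU.
    with torch.autocast(device_type=device, enabled=use_amp):
        logits = model(inputs)  # (N, T, C)
    # Loss computed outside autocast, explicitly in float32; shape (T, N, C).
    log_probs = F.log_softmax(logits.float(), dim=-1).transpose(0, 1)
    loss = F.ctc_loss(
        log_probs,
        flattened_targets,
        input_lengths,
        target_lengths,
        blank=0,
        reduction="mean",
        zero_infinity=True,
    )
    # Scale the loss to avoid fp16 gradient underflow; step() unscales and skips inf/nan steps.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


losses = [train_step() for _ in range(20)]
print(losses[0], losses[-1])
```

Keeping the loss outside the autocast block means any reduced precision is confined to the model body, which makes it easier to check whether the loss itself or the forward pass is the source of the problem.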

The log_probs tensor is created by F.log_softmax with dtype=torch.float32, so it should be float32 for the needed numerical stability. ctc_loss should then operate on these float32 inputs rather than on float16. Could you check the dtype of log_probs to verify this?
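A quick way to do that check is sketched below. It simulates float16 logits (as they might come out of an autocast region) with random data; the sizes and the blank index 0 are hypothetical and only mirror the shapes used in the snippet above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: batch N, T output frames, C classes (index 0 = blank), S labels per sample.
N, T, C, S = 2, 50, 32, 10

# Stand-in for logits produced inside an autocast region (float16).
logits = torch.randn(N, T, C).half()

# dtype=torch.float32 casts the input to float32 before the op,
# so log_probs is float32 even though logits are float16. Shape: (T, N, C).
log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
print(logits.dtype, log_probs.dtype)  # torch.float16 torch.float32

# Concatenated (1-D) targets like flattened_targets above; label 0 is reserved for blank.
target_lengths = torch.full((N,), S, dtype=torch.long)
flattened_targets = torch.randint(1, C, (N * S,))
feature_lengths = torch.full((N,), T, dtype=torch.long)

loss = F.ctc_loss(
    log_probs,
    flattened_targets,
    feature_lengths,
    target_lengths,
    blank=0,
    reduction="mean",
    zero_infinity=True,
)
print(loss.dtype)  # torch.float32
```

In the real model, printing log_probs.dtype right before the ctc_loss call gives the same information.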

Typically, CTC loss is one of the functions that are more demanding with respect to numerical accuracy. This shows up even in fp32 at times, and I imagine fp16 would not be a terribly good option for CTC.
Also, it would be nontrivial to get a compute speedup out of running CTC in fp16.
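To make the fp16 concern concrete, here is a small pure-Python sketch (standard library only, not PyTorch) that rounds values to IEEE 754 half precision via the struct module's "e" format. It illustrates general fp16 limits with illustrative values; it is not a reproduction of the CTC kernel, whose recursion sums over many alignment paths and long sequences, which is where such limits would matter.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision (fp16) and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# 1) Underflow: a small but valid probability vanishes in fp16.
p = math.exp(-20)            # about 2.06e-9 in double precision
p16 = to_fp16(p)             # below the smallest fp16 subnormal (2**-24, about 5.96e-8)
print(p, p16)                # p16 is 0.0, so its log would be -inf

# 2) Coarse resolution: fp16 has a 10-bit mantissa, so steps just above 1.0 are 2**-10.
print(to_fp16(1.0001))       # 1.0, the 1e-4 difference is lost

# 3) Large log-likelihoods: between 512 and 1024, fp16 values are spaced 0.5 apart.
print(to_fp16(-1000.3))      # -1000.5
```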

Best regards

Thomas
