CTC loss going down but model predicts only blanks

I am using Microsoft’s TrOCR model as the base and fine-tuning it with LoRA, using torch.nn.functional.ctc_loss as the loss function. The loss keeps going down, but after a few batches the model only ever outputs blank strings. I know there are a lot of similar questions, and I have tried some of those solutions, but nothing seems to work. I’d greatly appreciate any advice.
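For context, by "blank strings" I mean that greedy CTC decoding (collapse repeats, then drop the blank id) returns nothing, because the argmax is the blank id at every timestep. A minimal sketch of that decoding, just to illustrate (this is not my exact inference code):

import torch

def ctc_greedy_decode(logits: torch.Tensor, blank_id: int) -> list[int]:
    # logits: (T, C) for a single sequence
    ids = logits.argmax(dim=-1).tolist()
    decoded, prev = [], None
    for i in ids:
        # collapse repeats, then drop blanks
        if i != prev and i != blank_id:
            decoded.append(i)
        prev = i
    return decoded

If the argmax is blank_id at every timestep, this returns an empty list, which is exactly what I see.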

This is my LoRA config:

from peft import LoraConfig

lora_config = LoraConfig(
    inference_mode=False,
    r=64,
    lora_alpha=8,
    lora_dropout=0.1,
    # apply LoRA to the attention projections
    target_modules=["key", "query", "value"],
)
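
For completeness, the model is wrapped the standard peft way, roughly like this (the checkpoint name here is an example; I use one of the microsoft/trocr-* checkpoints):

from transformers import VisionEncoderDecoderModel
from peft import get_peft_model

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # sanity check: only the LoRA weights should be trainable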

This is my training loop:

import os
import random
import winsound

import torch
from torch import optim
from torch.utils.data import DataLoader

# env, logger, model, tokenizer, dataset, infer_file and save_progress
# are defined elsewhere in my project

dataloader = DataLoader(dataset, batch_size=env.BATCH_SIZE, shuffle=True)
optimizer = optim.Adam(model.parameters(), lr=env.LEARNING_RATE)

epoch_counter = 1
batch_counter = 1
losses = []
done_training = False

while not done_training:
    for images, labels, lengths in dataloader:
        if os.path.exists(env.INTERRUPT_FILE):
            logger.danger("interrupt detected, exiting.")
            done_training = True
        if os.path.exists(env.HALT_FILE):
            save_progress()
            done_training = True
        if done_training:
            break
        logger.info(f"epoch: {epoch_counter}, batch: {batch_counter}")
        optimizer.zero_grad()
        outputs = model(pixel_values=images, labels=labels)
        targets = labels.view(-1, labels.shape[-1])
        # ctc_loss expects log_probs of shape (T, N, C)
        log_probs = torch.nn.functional.log_softmax(outputs.logits, 2).permute(
            1, 0, 2
        )
        loss = torch.nn.functional.ctc_loss(
            log_probs=log_probs,
            targets=targets,
            input_lengths=torch.full((labels.shape[0],), labels.shape[1]),
            target_lengths=lengths,
            # I use the tokenizer's pad token id as the CTC blank
            blank=tokenizer.pad_token_id,
            reduction="mean",
            zero_infinity=True,
        )
        losses.append(loss.item())
        cur_loss, avg_loss = round(loss.item(), 3), round(
            sum(losses) / len(losses), 3
        )
        logger.info(f"current loss: {cur_loss}, average loss: {avg_loss}")
        loss.backward()
        optimizer.step()

        # check progress with inference every once in a while
        if batch_counter % env.PROGRESS_CHECK_INTERVAL == 0:
            if env.USE_SOUND:
                winsound.Beep(750, 750)
            for _ in range(env.PROGRESS_CHECK_ITERATIONS):
                file = random.choice(dataset.files)
                prediction = infer_file(file)
                logger.info(
                    f"target: {os.path.splitext(file)[0]}, prediction: {prediction}"
                )
        logger.log("-" * 80)
        batch_counter += 1
    epoch_counter += 1
    batch_counter = 1
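
For reference, per the PyTorch docs, torch.nn.functional.ctc_loss expects the following shapes, and as far as I can tell mine line up (the assert is just a sanity check one could drop into the loop):

# log_probs:      (T, N, C)  T timesteps, N batch size, C number of classes
# targets:        (N, S)     S = max target length, padded
# input_lengths:  (N,)       valid timesteps per sample in log_probs
# target_lengths: (N,)       valid token count per sample in targets
assert log_probs.shape[1] == targets.shape[0] == lengths.shape[0]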