Training a model with RNN-T loss fails with RuntimeError: output length mismatch

Here’s my code:

from data_loader import train_dataloader
from torchaudio.prototype.models import conformer_rnnt_model
from torch.optim import AdamW
from pytorch_lightning import LightningModule
from torchaudio.functional import rnnt_loss
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor,ModelCheckpoint
import torch


torch.set_float32_matmul_precision("medium")
none_count = 0


class MyConformer(LightningModule):

    def __init__(self,lr=1e-3):
        super(MyConformer, self).__init__()

        self.conformer = conformer_rnnt_model(input_dim=80,
        encoding_dim=256,
        conformer_depthwise_conv_kernel_size=31,
        conformer_dropout=0.1,
        conformer_ffn_dim=1200,
        conformer_input_dim=256,
        conformer_num_heads=4,
        conformer_num_layers=16,
        joiner_activation='relu',
        lstm_dropout=0,
        lstm_hidden_dim=640,
        lstm_layer_norm=False,
        lstm_layer_norm_epsilon=1e-8,
        num_lstm_layers=1,
        num_symbols=128,
        symbol_embedding_dim=512,
        time_reduction_stride=3,)
        self.lr = lr

   
    def forward(self, x,x_lengths,y,y_lengths):

        output,output_source_length,output_target_length,_ = self.conformer(x,x_lengths,y,y_lengths)
        return output.half(),output_source_length,output_target_length

    
    def training_step(self, batch, batch_idx):

        spectrograms, transcriptions, specs_lengths, transcriptions_lengths = batch
        outputs,output_source_length,output_target_length = self(spectrograms,specs_lengths,transcriptions,transcriptions_lengths)
        loss =  rnnt_loss(outputs, transcriptions, output_source_length, transcriptions_lengths)


        if torch.isnan(loss) or torch.isinf(loss):
            global none_count
            none_count+=1
            self.log('skipped_samples',float(none_count),prog_bar=True)
            return None
        self.log('loss',loss,batch_size=spectrograms.shape[0],prog_bar=True)

        return loss



    def configure_optimizers(self):

        optimizer = AdamW(self.parameters(), lr=self.lr,betas=[0.9,0.98],eps=1e-9,weight_decay=1e-3)

        return optimizer


# trainer.fit(model, train_dataloader)
if __name__ == '__main__':
    
    callbacks= [LearningRateMonitor(logging_interval='epoch'),
        ModelCheckpoint(dirpath="./checkpoints_v6",verbose=True,save_on_train_epoch_end=True,save_top_k=1,save_last=True,monitor='val_loss'),
        ]

    model = MyConformer(5.0)
    trainer = Trainer(accelerator='auto',
    precision="bf16",
    max_epochs=50,callbacks=callbacks,
    default_root_dir='./checkpoints/logs',

    )
    # torch.save(model.state_dict(),"conformer_v6")


    trainer.fit(model, train_dataloaders=train_dataloader)


I’m getting the error: RuntimeError: output length mismatch.
I hope someone can help me! Thanks in advance for any help.

Your code snippet doesn’t give any information where the issue might come from, but I guess it could be related to output_source_length,output_target_length based on the error message.
Do you know if these tensor values are checked and compared against each other somewhere inside the Trainer class and did you check if these values are equal (or are supposed to be equal)?

Thanks for your response. I don’t think it’s related to those, because if I modify one of them I get different errors (input length mismatch / target length mismatch). I think the issue is with the shape of the model output: according to the docs, the logits argument for the RNN-T loss has shape (batch, max seq length, max target length + 1, class), but I get (batch, max seq length, max target length, class), so for some reason it’s missing the +1. Here is the output of print(outputs.shape) and print(max(output_target_length)):
torch.Size([2, 44, 14, 128])
tensor(14, device='cuda:0', dtype=torch.int32)
It should be 15 instead of 14 in outputs.shape.
I also tried an RNN-T loss function from a third-party package and got the same error.
So unless I’m missing something, I think it’s a bug in torchaudio.prototype.models.conformer_rnnt_model, and since it’s an experimental feature I wouldn’t be too surprised.
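The shape contract described in this reply can be checked before calling the loss. The following is a minimal sketch, assuming the three error messages map onto the documented shapes as the poster observed; `rnnt_shape_problems` is a hypothetical helper, not part of torchaudio, and the per-sample lengths are made up for illustration.

```python
def rnnt_shape_problems(logits_shape, targets_shape, logit_lengths, target_lengths):
    """Compare shapes against the documented rnnt_loss layout.

    logits:  (batch, max seq length, max target length + 1, class)
    targets: (batch, max target length)
    Returns a list of mismatch descriptions (empty if everything lines up).
    """
    problems = []
    if logits_shape[1] != max(logit_lengths):
        problems.append("input length mismatch")
    if targets_shape[1] != max(target_lengths):
        problems.append("target length mismatch")
    if logits_shape[2] != max(target_lengths) + 1:
        problems.append("output length mismatch")
    return problems


# Shapes reported in the thread: outputs.shape == [2, 44, 14, 128], max target length 14.
# The individual lengths (44, 40) and (14, 10) are hypothetical.
print(rnnt_shape_problems((2, 44, 14, 128), (2, 14), [44, 40], [14, 10]))
# -> ['output length mismatch']

# With one extra step along the target dimension, the same check passes.
print(rnnt_shape_problems((2, 44, 15, 128), (2, 14), [44, 40], [14, 10]))
# -> []
```

In a training step it could be called with `tuple(outputs.shape)`, `tuple(transcriptions.shape)`, and the length tensors converted via `.tolist()`.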

I don’t know what might be causing the issue as I’m not deeply familiar with this model but the current nightly binaries pass all RNNT-related unit tests:

pytest test/torchaudio_unittest/prototype/ -v -k rnnt
========================================================= test session starts ==========================================================
platform linux -- Python 3.8.15, pytest-7.2.2, pluggy-1.0.0 -- /home/pbialecki/miniforge3/envs/nightly_pip_cuda118/bin/python3.8
cachedir: .pytest_cache
rootdir: /home/pbialecki/libs/upstream/audio
collected 1166 items / 1134 deselected / 32 selected                                                                                   

test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat32CPUTest::test_output_shape_forward PASSED               [  3%]
test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat32CPUTest::test_output_shape_join PASSED                  [  6%]
test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat32CPUTest::test_output_shape_predict PASSED               [  9%]
test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat32CPUTest::test_output_shape_transcribe PASSED            [ 12%]
test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat32CPUTest::test_torchscript_consistency_forward PASSED    [ 15%]
test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat32CPUTest::test_torchscript_consistency_join PASSED       [ 18%]
test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat32CPUTest::test_torchscript_consistency_predict PASSED    [ 21%]
test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat32CPUTest::test_torchscript_consistency_transcribe PASSED [ 25%]
test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat64CPUTest::test_output_shape_forward PASSED               [ 28%]
test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat64CPUTest::test_output_shape_join PASSED                  [ 31%]
test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat64CPUTest::test_output_shape_predict PASSED               [ 34%]
test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat64CPUTest::test_output_shape_transcribe PASSED            [ 37%]
test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat64CPUTest::test_torchscript_consistency_forward PASSED    [ 40%]
test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat64CPUTest::test_torchscript_consistency_join PASSED       [ 43%]
test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat64CPUTest::test_torchscript_consistency_predict PASSED    [ 46%]
test/torchaudio_unittest/prototype/rnnt_cpu_test.py::ConformerRNNTFloat64CPUTest::test_torchscript_consistency_transcribe PASSED [ 50%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat32GPUTest::test_output_shape_forward PASSED               [ 53%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat32GPUTest::test_output_shape_join PASSED                  [ 56%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat32GPUTest::test_output_shape_predict PASSED               [ 59%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat32GPUTest::test_output_shape_transcribe PASSED            [ 62%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat32GPUTest::test_torchscript_consistency_forward PASSED    [ 65%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat32GPUTest::test_torchscript_consistency_join PASSED       [ 68%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat32GPUTest::test_torchscript_consistency_predict PASSED    [ 71%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat32GPUTest::test_torchscript_consistency_transcribe PASSED [ 75%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat64GPUTest::test_output_shape_forward PASSED               [ 78%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat64GPUTest::test_output_shape_join PASSED                  [ 81%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat64GPUTest::test_output_shape_predict PASSED               [ 84%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat64GPUTest::test_output_shape_transcribe PASSED            [ 87%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat64GPUTest::test_torchscript_consistency_forward PASSED    [ 90%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat64GPUTest::test_torchscript_consistency_join PASSED       [ 93%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat64GPUTest::test_torchscript_consistency_predict PASSED    [ 96%]
test/torchaudio_unittest/prototype/rnnt_gpu_test.py::ConformerRNNTFloat64GPUTest::test_torchscript_consistency_transcribe PASSED [100%]

================================================= 32 passed, 1134 deselected in 22.67s =================================================

Maybe @nateanl might know what could be causing the issue or if the setup is unsupported.

Hi @g1777, I think I know where the issue is. You are right, the logits shape should be (batch, max seq length, max target length + 1, class). This is reflected in audio/lightning.py at main · pytorch/audio · GitHub.

Basically, you just need to increase the target length by 1 (for example, by prepending a blank token to the targets) when calling the forward method of the conformer model, while still passing the original targets and target lengths to the loss. That should fix the issue.
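One way to lengthen the targets by one step is sketched below, under stated assumptions: `prepend_blank` is a hypothetical helper, and the blank index 127 assumes that `rnnt_loss` is left at its default `blank=-1` (the last class) with `num_symbols=128` as in the question. Adjust it if your vocabulary reserves a different blank.

```python
import torch


def prepend_blank(targets: torch.Tensor, target_lengths: torch.Tensor, blank: int):
    """Return targets with one blank token prepended and lengths increased by 1."""
    blank_col = torch.full(
        (targets.size(0), 1), blank, dtype=targets.dtype, device=targets.device
    )
    return torch.cat([blank_col, targets], dim=1), target_lengths + 1


# Toy batch: 2 padded transcriptions with max target length U = 3.
targets = torch.tensor([[5, 6, 7], [8, 9, 0]], dtype=torch.int32)
target_lengths = torch.tensor([3, 2], dtype=torch.int32)

# blank=127 is an assumption (last of 128 symbols), see the note above.
model_targets, model_target_lengths = prepend_blank(targets, target_lengths, blank=127)
print(model_targets.shape)            # torch.Size([2, 4]), i.e. U + 1 columns
print(model_target_lengths.tolist())  # [4, 3]

# In training_step, only the model sees the lengthened targets; the loss
# still receives the original targets and lengths:
#   outputs, src_lengths, _, _ = self.conformer(x, x_lengths, model_targets, model_target_lengths)
#   loss = rnnt_loss(outputs, targets, src_lengths, target_lengths)
```

Building the extra column from `targets.dtype` and `targets.device` keeps the helper independent of whether training runs on CPU or GPU.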

That was the issue, @nateanl. I added
tensor2 = torch.zeros([y.shape[0], 1], dtype=torch.int32, device=y.device)
y = torch.cat((y, tensor2), dim=1)
to the forward method to make the y tensor longer by 1, and it’s working now. Thank you all for your help.
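Why one extra column in y changes the logits shape can be illustrated with a toy joiner. This is only a sketch, assuming the joiner combines every encoder time step with every target step via a broadcast add; `toy_join` and the sizes below are illustrative and not the torchaudio implementation.

```python
import torch

B, T, U, D = 2, 44, 14, 8  # toy sizes; D is an arbitrary encoding width

source_encodings = torch.randn(B, T, D)             # stand-in for the encoder output
target_encodings = torch.randn(B, U, D)             # predictor output for U target steps
target_encodings_plus1 = torch.randn(B, U + 1, D)   # same, with one extra target step


def toy_join(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
    # Broadcast-add each time step against each target step: (B, T, U, D).
    return src.unsqueeze(2) + tgt.unsqueeze(1)


print(toy_join(source_encodings, target_encodings).shape)        # torch.Size([2, 44, 14, 8])
print(toy_join(source_encodings, target_encodings_plus1).shape)  # torch.Size([2, 44, 15, 8])
```

The third dimension of the result simply follows the number of target steps fed to the model, which is why feeding U + 1 steps yields the shape the loss expects.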