RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x4096 and 320x256) in conformer_rnnt_base transcribe method

I'm using

spectrogram = log(MelSpectrogram(n_mels=80, n_fft=1024, sample_rate=sr)(waveform)+ 1e-10).squeeze()

as input for the model.
The error occurs here, when calling decoder.forward:

decoder = RNNTBeamSearch(
    model=model,
    blank=10,
    step_max_tokens=24,
)


with torch.no_grad():
    all_transcriptions = []
    all_targets = []
    for i, batch in enumerate(test_dataloader):
        inputs, targets, input_lengths, target_lengths = batch
        transcriptions, transcriptions_lengths = model.transcribe(inputs, input_lengths)
        print(transcriptions_lengths)

        results = []
        for i in range(0, len(transcriptions)):
            result = decoder.forward(transcriptions, transcriptions_lengths, 128)
            results.append(result)

I also tried conformer_rnnt_model, and the only case where it works is when I set both input_dim=80 and
encoding_dim=80.
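
In case it helps narrow things down, here is a rough sketch of how I read the numbers in the error. It assumes conformer_rnnt_base is built with time_reduction_stride=4 and encoding_dim=1024 (my assumption about the factory defaults, not something shown above), and that the time reduction concatenates 4 consecutive frames before the first linear layer. stacked_dim is a hypothetical helper for the arithmetic, not a torchaudio function.

```python
def stacked_dim(feature_dim: int, stride: int = 4) -> int:
    """Frame width after `stride` consecutive frames are concatenated."""
    return feature_dim * stride


# Values below are assumptions about conformer_rnnt_base, not confirmed here.
INPUT_DIM = 80       # mel bins, matches n_mels=80
ENCODING_DIM = 1024  # assumed output width of model.transcribe

# What the first linear layer appears to expect (rows of mat2 in the error).
expected_width = stacked_dim(INPUT_DIM)

# What it appears to receive (columns of mat1 in the error).
received_width = stacked_dim(ENCODING_DIM)

print(expected_width, received_width)

# With encoding_dim=80 the two widths coincide, so the multiplication
# would go through even if the wrong tensor were being fed in.
print(stacked_dim(80) == expected_width)
```

If this reading holds, the 4096 would mean the decoder is being handed the 1024-dimensional output of model.transcribe rather than the 80-dimensional mel features, which would also be consistent with input_dim=80 and encoding_dim=80 running without an error.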
Thanks in advance for any help.