Here's my code:
from data_loader import train_dataloader
from torchaudio.prototype.models import conformer_rnnt_model
from torch.optim import AdamW
from pytorch_lightning import LightningModule
from torchaudio.functional import rnnt_loss
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor,ModelCheckpoint
import torch
torch.set_float32_matmul_precision("medium")
none_count = 0
class MyConformer(LightningModule):
    """Conformer-transducer (RNN-T) ASR model wrapped for Lightning training.

    Expects batches of ``(spectrograms, transcriptions, spec_lengths,
    transcription_lengths)`` and optimizes the RNN-T loss. Batches whose
    loss is NaN/Inf are skipped (returns ``None``) and counted.
    """

    def __init__(self, lr=1e-3):
        """Build the 80-dim-feature Conformer-RNNT and store the learning rate.

        Args:
            lr: AdamW learning rate (default 1e-3).
        """
        super().__init__()
        self.conformer = conformer_rnnt_model(
            input_dim=80,
            encoding_dim=256,
            conformer_depthwise_conv_kernel_size=31,
            conformer_dropout=0.1,
            conformer_ffn_dim=1200,
            conformer_input_dim=256,
            conformer_num_heads=4,
            conformer_num_layers=16,
            joiner_activation='relu',
            lstm_dropout=0,
            lstm_hidden_dim=640,
            lstm_layer_norm=False,
            lstm_layer_norm_epsilon=1e-8,
            num_lstm_layers=1,
            num_symbols=128,
            symbol_embedding_dim=512,
            time_reduction_stride=3,
        )
        self.lr = lr
        # Count of batches skipped due to non-finite loss. Was a module-level
        # global (`none_count`); an instance attribute keeps the model
        # self-contained and per-run.
        self.skipped_batches = 0

    def forward(self, x, x_lengths, y, y_lengths):
        """Run the transducer; returns (joiner logits, source lengths, target lengths)."""
        # BUG FIX: the original called `output.half()` here while the Trainer
        # runs under bf16 mixed precision — forcing float16 fights autocast
        # and can break `rnnt_loss`'s dtype expectations. Let autocast decide.
        output, output_source_length, output_target_length, _ = self.conformer(
            x, x_lengths, y, y_lengths)
        return output, output_source_length, output_target_length

    def training_step(self, batch, batch_idx):
        """Compute RNN-T loss for one batch; return None to skip non-finite losses."""
        spectrograms, transcriptions, specs_lengths, transcriptions_lengths = batch
        outputs, output_source_length, output_target_length = self(
            spectrograms, specs_lengths, transcriptions, transcriptions_lengths)
        # NOTE(review): `rnnt_loss` requires outputs.shape[2] ==
        # max(target_lengths) + 1, i.e. the targets fed to the predictor must
        # include one extra leading blank/SOS symbol relative to the lengths
        # passed here. A "RuntimeError: output length mismatch" means that
        # invariant is violated — verify how the dataloader pads/prepends
        # `transcriptions` vs. `transcriptions_lengths`.
        loss = rnnt_loss(outputs, transcriptions, output_source_length,
                         transcriptions_lengths)
        if not torch.isfinite(loss):
            self.skipped_batches += 1
            self.log('skipped_samples', float(self.skipped_batches),
                     prog_bar=True)
            # Returning None makes Lightning skip the optimizer step.
            return None
        self.log('loss', loss, batch_size=spectrograms.shape[0], prog_bar=True)
        return loss

    def configure_optimizers(self):
        """AdamW with transformer-style betas; no LR scheduler is attached."""
        optimizer = AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.98),
                          eps=1e-9, weight_decay=1e-3)
        return optimizer
if __name__ == '__main__':
    callbacks = [
        LearningRateMonitor(logging_interval='epoch'),
        # BUG FIX: this was monitor='val_loss', but the module defines no
        # validation_step and never logs 'val_loss', so ModelCheckpoint raises
        # at the end of the first epoch. Monitor the logged training 'loss'.
        ModelCheckpoint(dirpath="./checkpoints_v6", verbose=True,
                        save_on_train_epoch_end=True, save_top_k=1,
                        save_last=True, monitor='loss'),
    ]
    # BUG FIX: the model was built as MyConformer(5.0), i.e. an AdamW learning
    # rate of 5.0 — that diverges immediately and yields the NaN/Inf losses
    # the skip counter was papering over. Use the sane default instead.
    model = MyConformer(lr=1e-3)
    trainer = Trainer(
        accelerator='auto',
        precision="bf16",
        max_epochs=50,
        callbacks=callbacks,
        default_root_dir='./checkpoints/logs',
    )
    trainer.fit(model, train_dataloaders=train_dataloader)
I'm getting this error: `RuntimeError: output length mismatch`.
I hope someone can help me — thanks in advance for any help!