The error is raised by the following call:
wav, rate = TTSHubInterface.get_prediction(task, models[0], generator, sample)
The sample argument carries three values:
- src_tokens
- src_lengths
- src_lengths
- speaker
These are tensors, and they must all be on the same device as the model, as shown below:
# Pick the device first, so it is defined before anything is moved to it
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Move the model to that device
models = [models[0].to(device)]
TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
generator = task.build_generator(models, cfg)
sample = TTSHubInterface.get_model_input(task, text)
# Move every input tensor to the same device as the model
sample['net_input']['src_tokens'] = sample['net_input']['src_tokens'].to(device)
sample['net_input']['src_lengths'] = sample['net_input']['src_lengths'].to(device)
sample['speaker'] = sample['speaker'].to(device)
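
If you would rather not move each entry by hand, the three assignments above can be folded into one small helper. This is only a sketch: move_to_device is a hypothetical name, not part of fairseq, and it assumes that every value to be moved exposes a .to(device) method, as torch tensors do. Anything else (for example None or plain numbers) is returned unchanged.

```python
def move_to_device(obj, device):
    """Recursively move every tensor-like value in a nested container to `device`.

    Hypothetical helper (not a fairseq API). "Tensor-like" here means any
    object with a callable `.to` method, which covers torch tensors.
    """
    if callable(getattr(obj, "to", None)):
        return obj.to(device)
    if isinstance(obj, dict):
        # Rebuild the dict so nested entries such as sample['net_input'] are covered
        return {key: move_to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(item, device) for item in obj)
    # Non-tensor leaves (None, ints, strings, ...) are left as they are
    return obj
```

With this in place, calling sample = move_to_device(sample, device) right after get_model_input and before get_prediction should have the same effect as the three explicit .to(device) lines.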