(facebook/tts_transformer-zh-cv7_css10) Expected all tensors to be on the same device

The error appears when running the following call:

wav, rate = TTSHubInterface.get_prediction(task, models[0], generator, sample)

The sample argument holds three tensors (src_tokens and src_lengths under sample['net_input'], and speaker at the top level):

  1. src_tokens
  2. src_lengths
  3. speaker
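To find out which of these tensors is on the wrong device, a small diagnostic can walk the nested sample dict and record each tensor's device. This is a minimal sketch: report_devices is a hypothetical helper (not part of fairseq), and it assumes any value with a device attribute (such as a torch.Tensor) is a tensor, skipping None entries like prev_output_tokens.

```python
from typing import Any, Dict


def report_devices(obj: Any, prefix: str = "sample") -> Dict[str, str]:
    """Walk a nested dict/list/tuple and record the device of every tensor-like leaf.

    Hypothetical helper: any value exposing a ``device`` attribute (e.g. a
    torch.Tensor) is treated as a tensor; None and other values are skipped.
    """
    found: Dict[str, str] = {}
    if isinstance(obj, dict):
        for key, value in obj.items():
            found.update(report_devices(value, f"{prefix}[{key!r}]"))
    elif isinstance(obj, (list, tuple)):
        for i, value in enumerate(obj):
            found.update(report_devices(value, f"{prefix}[{i}]"))
    elif hasattr(obj, "device"):
        found[prefix] = str(obj.device)
    return found
```

Comparing print(report_devices(sample)) with print(next(models[0].parameters()).device) shows which input still sits on a different device from the model.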

These tensors must be on the same device as the model, as in the code below:

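import torch
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface

# models, cfg, task and text are assumed to be defined already (e.g. via
# fairseq.checkpoint_utils.load_model_ensemble_and_task_from_hf_hub).
# Pick the device first: it is needed before the model is moved.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move the model to the device, then build the generator.
models = [models[0].to(device)]
TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
generator = task.build_generator(models, cfg)

# Build the input and move every input tensor to the same device as the model.
sample = TTSHubInterface.get_model_input(task, text)
sample['net_input']['src_tokens'] = sample['net_input']['src_tokens'].to(device)
sample['net_input']['src_lengths'] = sample['net_input']['src_lengths'].to(device)
# 'speaker' can be None for models without speaker embeddings, so guard it.
if sample['speaker'] is not None:
    sample['speaker'] = sample['speaker'].to(device)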
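Instead of moving each key by hand, the whole sample can be moved in one pass. This is a minimal sketch: move_to_device is a hypothetical helper (not a fairseq API) that rebuilds nested dicts, lists and tuples, calls .to(device) on anything that has a to method (such as a torch.Tensor), and leaves None values such as prev_output_tokens untouched.

```python
from typing import Any


def move_to_device(obj: Any, device: Any) -> Any:
    """Recursively move every tensor-like value in a nested structure to `device`.

    Hypothetical helper: dicts, lists and tuples are rebuilt with their contents
    moved; any value with a ``to`` method (e.g. torch.Tensor) is moved with
    ``obj.to(device)``; None and all other values are returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: move_to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [move_to_device(value, device) for value in obj]
    if isinstance(obj, tuple):
        return tuple(move_to_device(value, device) for value in obj)
    if hasattr(obj, "to"):
        return obj.to(device)
    return obj
```

With this, the three per-key lines reduce to sample = move_to_device(sample, device), followed by the same TTSHubInterface.get_prediction call. Returning new containers rather than mutating in place keeps the original CPU sample intact in case it is needed again.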