Hi everyone!
I am unable to trace the Tacotron 2 model given in torchaudio
library.
Here’s my approach
import torch
from torch import nn
import torchaudio
device = "cuda" if torch.cuda.is_available() else "cpu"
bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
processor = bundle.get_text_processor()
tacotron2 = bundle.get_tacotron2().to(device)
text = "Hello world! Text to speech!"
processed, lengths = processor(text)
Let’s define a model wrapper for tracing
class Model(nn.Module):
def __init__(self):
super().__init__()
self.model = tacotron2
def forward(self, processed, lengths):
spec, _, _ = self.model.infer(processed, lengths)
return spec
model = Model()
spec = model(processed, lengths)
print(spec.shape) # torch.Size([1, 80, 193])
Let’s try to trace it
traced_model = torch.jit.trace(model, (processed, lengths))
Now I get the error
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[29], line 1
----> 1 traced_model = torch.jit.trace(model, (processed, lengths))
File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:794, in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
792 else:
793 raise RuntimeError("example_kwarg_inputs should be a dict")
--> 794 return trace_module(
795 func,
796 {"forward": example_inputs},
797 None,
798 check_trace,
799 wrap_check_inputs(check_inputs),
800 check_tolerance,
801 strict,
802 _force_outplace,
803 _module_class,
804 example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
805 _store_inputs=_store_inputs
806 )
807 if (
808 hasattr(func, "__self__")
809 and isinstance(func.__self__, torch.nn.Module)
810 and func.__name__ == "forward"
811 ):
812 if example_inputs is None:
File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:1023, in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
1020 torch.jit._trace._trace_module_map = trace_module_map
1021 register_submods(mod, "__module")
-> 1023 module = make_module(mod, _module_class, _compilation_unit)
1025 for method_name, example_inputs in inputs.items():
1026 if method_name == "forward":
1027 # "forward" is a special case because we need to trace
1028 # `Module.__call__`, which sets up some extra tracing, but uses
1029 # argument names of the real `Module.forward` method.
File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:604, in make_module(mod, _module_class, _compilation_unit)
602 if _module_class is None:
603 _module_class = TopLevelTracedModule
--> 604 return _module_class(mod, _compilation_unit=_compilation_unit)
File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:1169, in TracedModule.__init__(self, orig, id_set, _compilation_unit)
1167 if submodule is None:
1168 continue
-> 1169 tmp_module._modules[name] = make_module(
1170 submodule, TracedModule, _compilation_unit=None
1171 )
1173 script_module = torch.jit._recursive.create_script_module(
1174 tmp_module, lambda module: (), share_types=False, is_tracing=True
1175 )
1177 self.__dict__["_name"] = type(orig).__name__
File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:595, in make_module(mod, _module_class, _compilation_unit)
592 elif torch._jit_internal.module_has_exports(mod):
594 infer_methods_stubs_fn = torch.jit._recursive.make_stubs_from_exported_methods
--> 595 return torch.jit._recursive.create_script_module(
596 mod,
597 infer_methods_stubs_fn,
598 share_types=False,
599 is_tracing=True
600 )
601 else:
602 if _module_class is None:
File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_recursive.py:480, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
478 if not is_tracing:
479 AttributeTypeIsSupportedChecker().check(nn_module)
--> 480 return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_recursive.py:546, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
544 # Compile methods if necessary
545 if concrete_type not in concrete_type_store.methods_compiled:
--> 546 create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
547 # Create hooks after methods to ensure no name collisions between hooks and methods.
548 # If done before, hooks can overshadow methods that aren't exported.
549 create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)
File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_recursive.py:397, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
394 property_defs = [p.def_ for p in property_stubs]
395 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 397 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_recursive.py:898, in compile_unbound_method(concrete_type, fn)
894 stub = make_stub(fn, fn.__name__)
895 with torch._jit_internal._disable_emit_hooks():
896 # We don't want to call the hooks here since the graph that is calling
897 # this function is not yet complete
--> 898 create_methods_and_properties_from_stubs(concrete_type, (stub,), ())
899 return stub
File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_recursive.py:397, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
394 property_defs = [p.def_ for p in property_stubs]
395 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 397 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError: Couldn't find method: 'forward__0' on class: '__torch__.torch.nn.modules.rnn.LSTM (of Python compilation unit at: 0x6000007dcb58)'
Main error line is
RuntimeError: Couldn't find method: 'forward__0' on class: '__torch__.torch.nn.modules.rnn.LSTM (of Python compilation unit at: 0x6000007dcb58)'
Could any please help me solve this?
Thanks,
Rahul Bhalley