Unable to JIT trace Tacotron 2 (from `torchaudio` library)

Hi everyone!

I am unable to trace the Tacotron 2 model given in torchaudio library.

Here’s my approach

import torch
from torch import nn
import torchaudio

device = "cuda" if torch.cuda.is_available() else "cpu"

bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
processor = bundle.get_text_processor()
tacotron2 = bundle.get_tacotron2().to(device)

text = "Hello world! Text to speech!"
processed, lengths = processor(text)

Let’s define a model wrapper for tracing

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = tacotron2
    def forward(self, processed, lengths):
        spec, _, _ = self.model.infer(processed, lengths)
        return spec

model = Model()
spec = model(processed, lengths)
print(spec.shape) # torch.Size([1, 80, 193])

Let’s try to trace it

traced_model = torch.jit.trace(model, (processed, lengths))

Now I get the error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[29], line 1
----> 1 traced_model = torch.jit.trace(model, (processed, lengths))

File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:794, in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
    792         else:
    793             raise RuntimeError("example_kwarg_inputs should be a dict")
--> 794     return trace_module(
    795         func,
    796         {"forward": example_inputs},
    797         None,
    798         check_trace,
    799         wrap_check_inputs(check_inputs),
    800         check_tolerance,
    801         strict,
    802         _force_outplace,
    803         _module_class,
    804         example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
    805         _store_inputs=_store_inputs
    806     )
    807 if (
    808     hasattr(func, "__self__")
    809     and isinstance(func.__self__, torch.nn.Module)
    810     and func.__name__ == "forward"
    811 ):
    812     if example_inputs is None:

File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:1023, in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
   1020 torch.jit._trace._trace_module_map = trace_module_map
   1021 register_submods(mod, "__module")
-> 1023 module = make_module(mod, _module_class, _compilation_unit)
   1025 for method_name, example_inputs in inputs.items():
   1026     if method_name == "forward":
   1027         # "forward" is a special case because we need to trace
   1028         # `Module.__call__`, which sets up some extra tracing, but uses
   1029         # argument names of the real `Module.forward` method.

File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:604, in make_module(mod, _module_class, _compilation_unit)
    602 if _module_class is None:
    603     _module_class = TopLevelTracedModule
--> 604 return _module_class(mod, _compilation_unit=_compilation_unit)

File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:1169, in TracedModule.__init__(self, orig, id_set, _compilation_unit)
   1167     if submodule is None:
   1168         continue
-> 1169     tmp_module._modules[name] = make_module(
   1170         submodule, TracedModule, _compilation_unit=None
   1171     )
   1173 script_module = torch.jit._recursive.create_script_module(
   1174     tmp_module, lambda module: (), share_types=False, is_tracing=True
   1175 )
   1177 self.__dict__["_name"] = type(orig).__name__

File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py:595, in make_module(mod, _module_class, _compilation_unit)
    592 elif torch._jit_internal.module_has_exports(mod):
    594     infer_methods_stubs_fn = torch.jit._recursive.make_stubs_from_exported_methods
--> 595     return torch.jit._recursive.create_script_module(
    596         mod,
    597         infer_methods_stubs_fn,
    598         share_types=False,
    599         is_tracing=True
    600     )
    601 else:
    602     if _module_class is None:

File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_recursive.py:480, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
    478 if not is_tracing:
    479     AttributeTypeIsSupportedChecker().check(nn_module)
--> 480 return create_script_module_impl(nn_module, concrete_type, stubs_fn)

File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_recursive.py:546, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    544 # Compile methods if necessary
    545 if concrete_type not in concrete_type_store.methods_compiled:
--> 546     create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    547     # Create hooks after methods to ensure no name collisions between hooks and methods.
    548     # If done before, hooks can overshadow methods that aren't exported.
    549     create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)

File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_recursive.py:397, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    394 property_defs = [p.def_ for p in property_stubs]
    395 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 397 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)

File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_recursive.py:898, in compile_unbound_method(concrete_type, fn)
    894 stub = make_stub(fn, fn.__name__)
    895 with torch._jit_internal._disable_emit_hooks():
    896     # We don't want to call the hooks here since the graph that is calling
    897     # this function is not yet complete
--> 898     create_methods_and_properties_from_stubs(concrete_type, (stub,), ())
    899 return stub

File /opt/homebrew/lib/python3.9/site-packages/torch/jit/_recursive.py:397, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    394 property_defs = [p.def_ for p in property_stubs]
    395 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 397 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)

RuntimeError: Couldn't find method: 'forward__0' on class: '__torch__.torch.nn.modules.rnn.LSTM (of Python compilation unit at: 0x6000007dcb58)'

Main error line is

RuntimeError: Couldn't find method: 'forward__0' on class: '__torch__.torch.nn.modules.rnn.LSTM (of Python compilation unit at: 0x6000007dcb58)'

Could any please help me solve this?

Thanks,
Rahul Bhalley