torch.jit.script throws an error originating in PyTorch's own code

I’m seeing the following error when trying to compile a model to TorchScript using torch.jit.script:

Traceback (most recent call last):                                                                                                     
  File "espnet_to_jit.py", line 31, in <module>
    script_model = torch.jit.script(asr_model)                                                                                         
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_script.py", line 943, in script
    obj, torch.jit._recursive.infer_methods_to_compile                                                                                                   
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_recursive.py", line 391, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)                                                               
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_recursive.py", line 448, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)                                                    
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_script.py", line 391, in _construct
    init_fn(script_module)                                                                                                                               
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_recursive.py", line 428, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)                                                      
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_recursive.py", line 448, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)                                                    
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_script.py", line 391, in _construct
    init_fn(script_module)                                                                                                                               
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_recursive.py", line 428, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)                                                      
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_recursive.py", line 448, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)                                                    
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_script.py", line 391, in _construct
    init_fn(script_module)                                                                                                                               
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_recursive.py", line 428, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)                                                                                       
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_recursive.py", line 448, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)                                                                   
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_script.py", line 391, in _construct
    init_fn(script_module)                                                                                                                             
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_recursive.py", line 428, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)                                                              
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_recursive.py", line 452, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)                                               
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_recursive.py", line 335, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)                
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_script.py", line 1106, in _recursive_compile_class
    _compile_and_register_class(obj, rcb, _qual_name)                                                                                         
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/_script.py", line 65, in _compile_and_register_class
    ast = get_jit_class_def(obj, obj.__name__)                                              
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/frontend.py", line 176, in get_jit_class_def
    is_classmethod=is_classmethod(method[1])) for method in methods]          
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/frontend.py", line 176, in <listcomp>
    is_classmethod=is_classmethod(method[1])) for method in methods]                                       
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/frontend.py", line 248, in get_jit_def
    type_line = torch.jit.annotations.get_type_line(source)                                                    
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/jit/annotations.py", line 205, in get_type_line
    "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)")  # noqa
RuntimeError: Return type line '# type: (...) -> ...' not found on multiline type annotation
for type lines:                                                               
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]   
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)

I added a print statement to annotations.py to show the source being parsed when the error occurs, and it shows:

Source is:
    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(torch.Tensor(*self.normalized_shape))
            self.bias = Parameter(torch.Tensor(*self.normalized_shape))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

I traced this code to `LayerNorm.__init__` in torch/nn/modules/normalization.py.

Why is torch.jit.script compiling PyTorch's own source code, and how can I fix this error?

I got past this issue by removing the `# type: ignore[...]` comments, but then immediately ran into another error while compiling PyTorch code:

torch.jit.frontend.NotSupportedError: keyword-arg expansion is not supported:
  File "/shared/workspaces/rwesterman/miniconda3/envs/espnet/lib/python3.7/site-packages/torch/nn/modules/normalization.py", line 175
    def extra_repr(self) -> str:
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
                                                               ~~~~~~~~~~~~~ <--- HERE
'__torch__.torch.nn.modules.normalization.LayerNorm' is being compiled since it was called from 'LayerNorm.forward'
  File "/shared/workspaces/rwesterman/rev-espnet/espnet/nets/pytorch_backend/transformer/layer_norm.py", line 37
        """
        if self.dim == -1:
            return super(torch.nn.LayerNorm, self).forward(x)
                         ~~~~~~~~~~~~~~~~~~ <--- HERE
        return super(torch.nn.LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)