Trying to use TorchScript on torch.nn.Transformer

Hello everyone. I am trying to use torch.jit.script on torch.nn.Transformer, but it doesn't work.
Has anyone done any related work?

I built PyTorch from source; the torch version is 1.4.0a0+2e7dd54.

I'd appreciate it if anybody could help! Or, if there is a working implementation, please let me know. Thanks in advance!

Here is the code:

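import torch
import torch.nn as nn
torch.manual_seed(2)
# 16 attention heads and 12 encoder layers; all other settings are defaults
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
# Compile the whole module tree with TorchScript (this call raises the error below)
trans_model = torch.jit.script(transformer_model)
trans_model.save('test.pt')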

And here is the log:

Traceback (most recent call last):
  File "transformer_demo.py", line 13, in <module>
    trans_model = torch.jit.script(transformer_model)
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/__init__.py", line 1239, in script
    return torch.jit.torch.jit._recursive.recursive_script(obj)
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 508, in recursive_script
    return create_script_module(nn_module, infer_methods_to_compile(nn_module))
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 305, in create_script_module
    concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 243, in get_or_create_concrete_type
    raw_concrete_type = infer_raw_concrete_type(nn_module)
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 91, in infer_raw_concrete_type
    sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 243, in get_or_create_concrete_type
    raw_concrete_type = infer_raw_concrete_type(nn_module)
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 91, in infer_raw_concrete_type
    sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 240, in get_or_create_concrete_type
    scripted = create_constant_iterable_module(nn_module)
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 539, in create_constant_iterable_module
    modules[key] = recursive_script(submodule)
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 508, in recursive_script
    return create_script_module(nn_module, infer_methods_to_compile(nn_module))
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 308, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, cpp_module, stubs)
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 358, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/__init__.py", line 1612, in _construct
    init_fn(script_module)
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 340, in init_fn
    scripted = recursive_script(orig_value)
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 508, in recursive_script
    return create_script_module(nn_module, infer_methods_to_compile(nn_module))
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 308, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, cpp_module, stubs)
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 362, in create_script_module_impl
    create_methods_from_stubs(concrete_type, stubs)
  File "/home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_recursive.py", line 268, in create_methods_from_stubs
    concrete_type._create_methods(defs, rcbs, defaults)
RuntimeError:
Module 'MultiheadAttention' has no attribute 'q_proj_weight' :
at /home/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/activation.py:771:30
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                              ~~~~~~~~~~~~~~~~~~ <--- HERE
                v_proj_weight=self.v_proj_weight)
        else:
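The error points at a branch of MultiheadAttention.forward that reads self.q_proj_weight. TorchScript generally compiles both sides of an if on an ordinary module attribute, even the side that would never run for a given instance, so every attribute referenced there has to exist on the module. The sketch below reproduces that failure mode in miniature, assuming a recent PyTorch release rather than the 1.4.0a0 build above; ConditionalWeights, use_extra, extra, and scripts_ok are hypothetical names, not PyTorch APIs.

```python
import torch
import torch.nn as nn


class ConditionalWeights(nn.Module):
    # Hypothetical toy module: `extra` is only created when use_extra=True,
    # similar to a weight that only exists for some configurations.
    def __init__(self, use_extra: bool):
        super().__init__()
        self.use_extra = use_extra  # plain bool attribute, not a constant
        self.base = nn.Parameter(torch.ones(3))
        if use_extra:
            self.extra = nn.Parameter(torch.ones(3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_extra:
            # This branch is compiled even for instances where it never runs,
            # so `self.extra` must exist as an attribute.
            return x * self.base * self.extra
        return x * self.base


def scripts_ok(module: nn.Module) -> bool:
    """Hypothetical helper: True if torch.jit.script accepts the module."""
    try:
        torch.jit.script(module)
        return True
    except Exception as err:  # the log above shows this as a RuntimeError
        print(f"{type(err).__name__}: {err}")
        return False


print(scripts_ok(ConditionalWeights(use_extra=True)))   # expected: True
print(scripts_ok(ConditionalWeights(use_extra=False)))  # expected: False
```

The second call fails with a "has no attribute 'extra'" message of the same kind as the one in the log.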

This was brought up before, but the proposed fix never landed because of backwards-compatibility issues; I can try to get it through soon.

OK, thanks! I tried adding self.register_parameter('q_proj_weight', None), but it doesn't seem to work.
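register_parameter(name, None) makes the attribute exist on every instance (as None when unused), and the code that reads it then has to accept None. The thread notes this alone was not enough for MultiheadAttention in that build, so the sketch below only illustrates the general pattern on a recent PyTorch release, using hypothetical names (ConditionalWeightsFixed, apply_weights). It passes the possibly-None parameter into a function typed Optional[Tensor] and narrows it with an "is not None" check.

```python
from typing import Optional

import torch
import torch.nn as nn


def apply_weights(
    x: torch.Tensor, base: torch.Tensor, extra: Optional[torch.Tensor]
) -> torch.Tensor:
    # `extra is not None` refines Optional[Tensor] to Tensor inside the branch.
    if extra is not None:
        return x * base * extra
    return x * base


class ConditionalWeightsFixed(nn.Module):
    # Hypothetical toy module: `extra` always exists, it is just None when unused.
    def __init__(self, use_extra: bool):
        super().__init__()
        self.base = nn.Parameter(torch.ones(3))
        if use_extra:
            self.extra = nn.Parameter(torch.full((3,), 2.0))
        else:
            self.register_parameter("extra", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Handing the (possibly None) parameter to a function typed
        # Optional[Tensor] lets both configurations compile.
        return apply_weights(x, self.base, self.extra)


x = torch.ones(3)
for use_extra in (True, False):
    scripted = torch.jit.script(ConditionalWeightsFixed(use_extra))
    print(use_extra, scripted(x))
```

This is the same idea that keeps nn.Linear(bias=False) scriptable: its bias is registered as None and passed to a function that takes an optional tensor.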

It turns out some other fixes are needed as well; see these two PRs for details (#28555 and #28561). If you need it now, you can check out #28561 and build from source; otherwise it should be in the nightly package in a few days.

Thanks a lot, that's really helpful.