I've been trying to run `torch.jit.script(rpc_model)` on the RNN parallel model described in PyTorch's RPC tutorial, but I have been hitting issues.
My attempt consists of adding `script_model = torch.jit.script(model)` right before the training loop and commenting out the loop, since I don't need to train the model.
The first error I got was caused by TorchScript not supporting the `*args` and `**kwargs` used in `_remote_method` and `_call_method`.
In an attempt to bypass this error, I changed `RNNModel.forward` to a simple `emb = rpc_sync(self.emb_table_rref.owner(), EmbeddingTable.forward, args=input)` (and commented out all other calls to local and remote layers), just to see whether this layer could be converted to TorchScript. It failed with a similar error. This time the error was not caused by any public-facing dictionary; it was raised inside `rpc_sync` itself: `torch.jit.frontend.NotSupportedError: keyword-arg expansion is not supported`. Here is the full stack trace:
$ mpirun -np 2 python rnn_jit.py
Traceback (most recent call last):
File "rnn_jit.py", line 156, in <module>
run_worker()
File "rnn_jit.py", line 145, in run_worker
_run_trainer()
File "rnn_jit.py", line 116, in _run_trainer
traced = torch.jit.script(model)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py", line 1261, in script
return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py", line 305, in create_script_module
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py", line 361, in create_script_module_impl
create_methods_from_stubs(concrete_type, stubs)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py", line 279, in create_methods_from_stubs
concrete_type._create_methods(defs, rcbs, defaults)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py", line 568, in try_compile_fn
return torch.jit.script(fn, _rcb=rcb)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py", line 1290, in script
fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
File "/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py", line 568, in try_compile_fn
return torch.jit.script(fn, _rcb=rcb)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py", line 1287, in script
ast = get_jit_def(obj)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py", line 173, in get_jit_def
return build_def(ctx, py_ast.body[0], type_line, self_name)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py", line 206, in build_def
build_stmts(ctx, body))
File "/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py", line 129, in build_stmts
stmts = [build_stmt(ctx, s) for s in stmts]
File "/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py", line 129, in <listcomp>
stmts = [build_stmt(ctx, s) for s in stmts]
File "/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py", line 181, in __call__
return method(ctx, node)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py", line 363, in build_If
build_stmts(ctx, stmt.body),
File "/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py", line 129, in build_stmts
stmts = [build_stmt(ctx, s) for s in stmts]
File "/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py", line 129, in <listcomp>
stmts = [build_stmt(ctx, s) for s in stmts]
File "/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py", line 181, in __call__
return method(ctx, node)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py", line 288, in build_Assign
rhs = build_expr(ctx, stmt.value)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py", line 181, in __call__
return method(ctx, node)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py", line 464, in build_Call
raise NotSupportedError(kw_expr.range(), 'keyword-arg expansion is not supported')
torch.jit.frontend.NotSupportedError: keyword-arg expansion is not supported:
File "/opt/conda/lib/python3.7/site-packages/torch/distributed/rpc/api.py", line 476
if qualified_name is not None:
fut = _invoke_rpc_builtin(dst_worker_info, qualified_name, rf, *args, **kwargs)
~~~~~~ <--- HERE
elif isinstance(func, torch.jit.ScriptFunction):
fut = _invoke_rpc_torchscript(dst_worker_info.name, func, args, kwargs)
'rpc_sync' is being compiled since it was called from 'RNNModel.forward'
File "rnn_jit.py", line 68
# pass input to the remote embedding table and fetch emb tensor back
# emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
emb = rpc_sync(self.emb_table_rref.owner(), EmbeddingTable.forward, args=input)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <— HERE
# output, hidden = self.rnn(emb, hidden)
# pass output to the rremote decoder and get the decoded output back
Primary job terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
mpirun detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was: