Does TorchScript support RPC-based models?

I’ve been trying to run torch.jit.script(rpc_model) on the RNN parallel model described at Pytorch’s RPC tutorial, but have been hitting issues.

My attempt includes adding script_model = torch.jit.script(model) tight before training loop and comment out the loop, as I dont need to train the model.

The first error I got is related to TorchScript not supportting *args and **kwargs from _remote_method and _call_method. In an attempt to bypass this error, I have changed RNNModel.forward to a simple emb = rpc_sync(self.emb_table_rref.owner(), EmbeddingTable.forward, args=input) (and commented out all other calls to local and remote layers) just to see whether this layer could be converted to torchscript. It failed with the following a similar issue, but not due to any public facing dictionary: torch.jit.frontend.NotSupportedError: keyword-arg expansion is not supported:. Here is the full stack

$ mpirun -np 2 python rnn_jit.py
Traceback (most recent call last):
File “rnn_jit.py”, line 156, in
run_worker()
File “rnn_jit.py”, line 145, in run_worker
_run_trainer()
File “rnn_jit.py”, line 116, in _run_trainer
traced = torch.jit.script(model)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/init.py”, line 1261, in script
return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py”, line 305, in create_script_module
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py”, line 361, in create_script_module_impl
create_methods_from_stubs(concrete_type, stubs)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py”, line 279, in create_methods_from_stubs
concrete_type._create_methods(defs, rcbs, defaults)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py”, line 568, in try_compile_fn
return torch.jit.script(fn, _rcb=rcb)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/init.py”, line 1290, in script
fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
File “/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py”, line 568, in try_compile_fn
return torch.jit.script(fn, _rcb=rcb)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/init.py”, line 1287, in script
ast = get_jit_def(obj)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py”, line 173, in get_jit_def
return build_def(ctx, py_ast.body[0], type_line, self_name)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py”, line 206, in build_def
build_stmts(ctx, body))
File “/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py”, line 129, in build_stmts
stmts = [build_stmt(ctx, s) for s in stmts]
File “/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py”, line 129, in
stmts = [build_stmt(ctx, s) for s in stmts]
File “/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py”, line 181, in call
return method(ctx, node)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py”, line 363, in build_If
build_stmts(ctx, stmt.body),
File “/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py”, line 129, in build_stmts
stmts = [build_stmt(ctx, s) for s in stmts]
File “/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py”, line 129, in
stmts = [build_stmt(ctx, s) for s in stmts]
File “/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py”, line 181, in call
return method(ctx, node)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py”, line 288, in build_Assign
rhs = build_expr(ctx, stmt.value)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py”, line 181, in call
return method(ctx, node)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/frontend.py”, line 464, in build_Call
raise NotSupportedError(kw_expr.range(), ‘keyword-arg expansion is not supported’)
torch.jit.frontend.NotSupportedError: keyword-arg expansion is not supported:
File “/opt/conda/lib/python3.7/site-packages/torch/distributed/rpc/api.py”, line 476

if qualified_name is not None:
    fut = _invoke_rpc_builtin(dst_worker_info, qualified_name, rf, *args, **kwargs)
                                                                            ~~~~~~ <--- HERE
elif isinstance(func, torch.jit.ScriptFunction):
    fut = _invoke_rpc_torchscript(dst_worker_info.name, func, args, kwargs)

‘rpc_sync’ is being compiled since it was called from ‘RNNModel.forward’
File “rnn_jit.py”, line 68
# pass input to the remote embedding table and fetch emb tensor back
# emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
emb = rpc_sync(self.emb_table_rref.owner(), EmbeddingTable.forward, args=input)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <— HERE
# output, hidden = self.rnn(emb, hidden)
# pass output to the rremote decoder and get the decoded output back


Primary job terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.


mpirun detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:

Process name: [[29685,1],1]
Exit code: 1

Hey @Thiago.Crepaldi

TorchScript integration with RPC is still experimental, and we are working on closing the gaps. Currently, in v1.5, applications can run TorchScript functions using RPC, e.g., rpc_sync/rpc_async/remote(to, my_script_func, args=(...)). But, within a script function only rpc_async can be called. See the test below:

Could you please try if using rpc_async works for you?

The issue below is tracking the progress on adding native RemoteModule support. Please feel free to comment there.

Thanks for the quick reply, Shen Li. By using rpc_async, I got similar error:

$ mpirun -np 2 python rnn_jit.py
Traceback (most recent call last):
File “rnn_jit.py”, line 156, in
run_worker()
File “rnn_jit.py”, line 145, in run_worker
_run_trainer()
File “rnn_jit.py”, line 116, in _run_trainer
script_model = torch.jit.script(model)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/init.py”, line 1261, in script
return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py”, line 305, in create_script_module
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py”, line 361, in create_script_module_impl
create_methods_from_stubs(concrete_type, stubs)
File “/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py”, line 279, in create_methods_from_stubs
concrete_type._create_methods(defs, rcbs, defaults)
RuntimeError:
rpc_async(dst_worker_name, user_callable, args, kwargs)does not support kwargs yet:
File “rnn_jit.py”, line 68
# pass input to the remote embedding table and fetch emb tensor back
# emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
emb = rpc_async(self.emb_table_rref.owner(), EmbeddingTable.forward, args=input)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <— HERE
# output, hidden = self.rnn(emb, hidden)
# pass output to the rremote decoder and get the decoded output back

Line 68 refers to the rpc_async call inside the forward method
emb = rpc_async(self.emb_table_rref.owner(), EmbeddingTable.forward, args=input)

This script is a repro for the issue

import torch
import os
import torch.distributed as dist
from torch.distributed.rpc import rpc_sync
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed.rpc as rpc
from torch import optim #XXX include this
from torch.distributed.optim import DistributedOptimizer #XXX include this
from torch.distributed.rpc import RRef, rpc_async, remote
import torch.distributed.autograd as dist_autograd #XXX include this

def get_local_rank():
    return int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])

def _parameter_rrefs(module):
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs

class EmbeddingTable(nn.Module):
    r"""
    Encoding layers of the RNNModel
    """
    def __init__(self, ntoken, ninp, dropout):
        super(EmbeddingTable, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp).cuda()
        self.encoder.weight.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        return self.drop(self.encoder(input.cuda())).cpu() # XXX: extra ')'


class Decoder(nn.Module):
    def __init__(self, ntoken, nhid, dropout):
        super(Decoder, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-0.1, 0.1)

    def forward(self, output):
        return self.decoder(self.drop(output))

class RNNModel(nn.Module):
    def __init__(self, ps, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()

        # setup embedding table remotely
        self.emb_table_rref = rpc.remote(ps, EmbeddingTable, args=(ntoken, ninp, dropout))
        # setup LSTM locally
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        # setup decoder remotely
        self.decoder_rref = rpc.remote(ps, Decoder, args=(ntoken, nhid, dropout))

    def forward(self, input, hidden):
        # pass input to the remote embedding table and fetch emb tensor back
        # emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input) # Original call
        emb = rpc_async(self.emb_table_rref.owner(), EmbeddingTable.forward, args=input) # adapted call

        output, hidden = self.rnn(emb, hidden)

        # pass output to the rremote decoder and get the decoded output back
        # decoded = _remote_method(Decoder.forward, self.decoder_rref, output) # Original call
        decoded = rpc_async(self.decoder_rref.owner(), Decoder.forward, args=output) # adapted call
        return decoded, hidden
    def parameter_rrefs(self):
        remote_params = []
        # get RRefs of embedding table
        # remote_params.extend(_remote_method(_parameter_rrefs, self.emb_table_rref)) # Original call
        remote_params.extend(rpc_async(_parameter_rrefs, self.emb_table_rref))

        # create RRefs for local parameters
        remote_params.extend(_parameter_rrefs(self.rnn))

        # get RRefs of decoder
        # remote_params.extend(_remote_method(_parameter_rrefs, self.decoder_rref)) # Original call
        remote_params.extend(rpc_async(_parameter_rrefs, self.decoder_rref)) # Adapted call
        return remote_params

def _run_trainer():
    ntoken = 10
    ninp = 2
    nhid = 3
    nlayers = 4
    model = RNNModel('ps', ntoken, ninp, nhid, nlayers) # XXX: no rnn.
    script_model = torch.jit.script(model)
    print(script_model)

def run_worker():
    world_size=2
    rank=get_local_rank()
    os.environ['MASTER_ADDR'] = '10.123.134.28'
    os.environ['MASTER_PORT'] = '21234'
    if rank == 1:
        rpc.init_rpc("trainer", rank=rank, world_size=world_size)
        _run_trainer()
    else:
        rpc.init_rpc("ps", rank=rank, world_size=world_size)
        # parameter server do nothing
        pass
    # block until all rpcs finish
    rpc.shutdown()

if __name__=="__main__":
    run_worker()