Running RNNs in parallel on single GPU

First, I have very little experience with jit, so this may be an issue with my knowledge there - but I’ve scoured every resource I can find over the past several days to no avail.

Let's say you have an nn.ModuleList with a collection of RNNs, and you want to process the same data through each. CUDA operations are async - so we do something like this:

# One RNN per slot, all fed the same input.
self.networks = nn.ModuleList(
    [nn.RNN(input_size=self.feature_count, hidden_size=self.hidden_size) for _ in range(self.count_rnn)]
)

all_results = []
for rnn in self.networks:
    res, _ = rnn(input)              # discard the hidden state, keep the output sequence
    all_results.append(res)
all_results = torch.stack(all_results, dim=-1)   # stack per-RNN outputs along a new trailing dim

Although CUDA operations are async, this simple for loop over the RNNs makes the iteration time grow roughly linearly with self.count_rnn (the length of the ModuleList).
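For reference, this is roughly how I measure it (a minimal standalone sketch; the sizes below are placeholders, not my real config). The torch.cuda.synchronize() calls make sure the async kernel launches have actually finished before the clock is read:

# Minimal timing sketch (placeholder sizes, not my real config).
import time
import torch
import torch.nn as nn

feature_count, hidden_size, count_rnn = 32, 64, 8
device = torch.device("cuda")

networks = nn.ModuleList(
    [nn.RNN(input_size=feature_count, hidden_size=hidden_size) for _ in range(count_rnn)]
).to(device)
x = torch.randn(100, 16, feature_count, device=device)  # (seq_len, batch, features)

torch.cuda.synchronize()
start = time.perf_counter()
outputs = []
for rnn in networks:
    res, _ = rnn(x)
    outputs.append(res)
stacked = torch.stack(outputs, dim=-1)
torch.cuda.synchronize()
print(f"{count_rnn} RNNs: {time.perf_counter() - start:.4f}s")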

So something extra is needed to parallelize the work, which brings us to torch.jit.fork. Consider the following implementation:

from typing import List

import torch
import torch.nn as nn


class ParallelModuleList(nn.ModuleList):
    def __init__(self, modules=None):
        super(ParallelModuleList, self).__init__(modules)

    def forward(self, x):
        # Fork every submodule on the same input, then wait on all the futures.
        futures : List[torch.futures.Future[torch.Tensor]] = []
        for module in self:
            future : torch.futures.Future[torch.Tensor] = torch.jit.fork(module, x)
            futures.append(future)
        output : List[torch.Tensor] = []
        for future in futures:
            output.append(torch.jit.wait(future))
        return output

and used like this:

self.networks : nn.ModuleList = torch.jit.script(
    ParallelModuleList(
        [nn.RNN(input_size=self.feature_count, hidden_size=self.hidden_size) for _ in range(self.count_rnn)]
    )
)

You may wonder why there is no list comprehension and why all the explicit typing: I kept running into errors about the futures not being typed, so I added the separation and annotations to work around them. I was also unable to use torch.jit.wait_all; even with the typing, it kept raising errors about unset types.
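For context, this is the kind of standalone fork/wait typing I modeled it on (a minimal sketch adapted from my reading of the torch.jit.fork docs; double is just a stand-in function, and note the docs appear to use torch.jit.Future rather than torch.futures.Future - I'm not sure whether that distinction matters here):

# Minimal standalone sketch of the fork/wait typing pattern.
import torch
from torch import Tensor


def double(x: Tensor) -> Tensor:
    return x * 2


def fan_out(x: Tensor) -> Tensor:
    # Annotate each future explicitly so TorchScript knows the element type.
    fut: torch.jit.Future[Tensor] = torch.jit.fork(double, x)
    other: torch.jit.Future[Tensor] = torch.jit.fork(double, x * 3)
    return torch.jit.wait(fut) + torch.jit.wait(other)


scripted_fan_out = torch.jit.script(fan_out)  # only the scripted version runs async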

Back to the results of this implementation - when it reaches the step to create the optimizer…

self.optimizer = torch.optim.SGD(params=self.embeddings_network.parameters(), lr=config.learning_rate, momentum=config.momentum)

…it errors with “ValueError: optimizer got an empty parameter list”. The ParallelModuleList is somehow not recognized as an nn.ModuleList (even when annotated as one), so no parameters are found. The plain ModuleList version above and my other networks work fine, so it’s something specific to this implementation.
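For what it’s worth, this is the throwaway sanity check I use (placeholder sizes, and it relies on the ParallelModuleList class above); the plain ModuleList reports parameters as expected, while the scripted wrapper is the one that comes back empty for me:

# Throwaway sanity check on parameter registration (placeholder sizes).
import torch
import torch.nn as nn

plain = nn.ModuleList([nn.RNN(input_size=8, hidden_size=16) for _ in range(4)])
print("plain ModuleList param count:", sum(p.numel() for p in plain.parameters()))

scripted = torch.jit.script(
    ParallelModuleList([nn.RNN(input_size=8, hidden_size=16) for _ in range(4)])
)
# This is the one that comes back empty for me and then breaks the optimizer.
print("scripted wrapper param count:", sum(p.numel() for p in scripted.parameters()))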

So I tried this implementation instead:

class ParallelModules(nn.Module):
    def __init__(self, module_list):
        super(ParallelModules, self).__init__()
        self.module_list = nn.ModuleList(module_list)

    def forward(self, x):
        # Same fork/wait pattern, but wrapping an nn.ModuleList instead of subclassing it.
        futures : List[torch.futures.Future[torch.Tensor]] = []
        for module in self.module_list:
            future : torch.futures.Future[torch.Tensor] = torch.jit.fork(module, x)
            futures.append(future)
        output : List[torch.Tensor] = []
        for future in futures:
            output.append(torch.jit.wait(future))
        return output

This one fails even earlier, at scripting time, with an error that even a Google search turns up nothing about:

File "/sandbox/parallelRNNs/parallelRNNs.py", line 36, in __init__
    self.networks : nn.Module = torch.jit.script(ParallelModules([nn.RNN(input_size=self.feature_count, hidden_size=self.hidden_size) for i in range(self.count_rnn)]))
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/_script.py", line 1097, in script
    obj, torch.jit._recursive.infer_methods_to_compile
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py", line 412, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py", line 474, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/_script.py", line 497, in _construct
    init_fn(script_module)
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py", line 452, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py", line 474, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/_script.py", line 497, in _construct
    init_fn(script_module)
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py", line 452, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py", line 424, in create_script_module_impl
    method_stubs = stubs_fn(nn_module)
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py", line 689, in infer_methods_to_compile
    stubs.append(make_stub_from_method(nn_module, method))
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py", line 53, in make_stub_from_method
    return make_stub(func, method_name)
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py", line 38, in make_stub
    ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/frontend.py", line 308, in get_jit_def
    return build_def(ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/frontend.py", line 359, in build_def
    build_stmts(ctx, body))
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/frontend.py", line 137, in build_stmts
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/frontend.py", line 137, in <listcomp>
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/frontend.py", line 331, in __call__
    return method(ctx, node)
  File "/usr/local/lib/python3.7/dist-packages/torch/jit/frontend.py", line 577, in build_AnnAssign
    raise UnsupportedNodeError(ctx, stmt, reason='without assigned value')
torch.jit.frontend.UnsupportedNodeError: annotated assignments without assigned value aren't supported:
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/rnn.py", line 274
                           self.num_layers, self.dropout, self.training, self.bidirectional)
    
        output: Union[Tensor, PackedSequence]
        ~ <--- HERE
        output = result[0]
        hidden = result[1]

I’m a bit at a loss. I’ve found multiple topics on running networks in parallel, and everyone either says “it runs in parallel by default” (which is clearly not the case, at least for nn.RNN modules), or “use nn.Sequential” (which is not appropriate here), or “just use jit” (which does not seem to work in this case either).

Does anyone have a solid, functional example of RNNs running on the same data in parallel on a single GPU in PyTorch? Or really, forget RNNs - ANY network?

I’ve posted an example and explanation here a while ago.
The CUDA kernels could run in parallel if enough resources are free and if the CPU is fast enough to run ahead and schedule the next kernel launch while the previous one is still executing.
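As a rough illustration of what that can look like, here is a sketch using explicit CUDA streams with placeholder sizes (not the exact example from the linked post); whether the kernels actually overlap still depends on free SM resources and on the CPU queueing the launches fast enough:

# Rough sketch: run each RNN's work on its own CUDA stream (placeholder sizes).
import torch
import torch.nn as nn

device = torch.device("cuda")
rnns = nn.ModuleList([nn.RNN(input_size=8, hidden_size=16) for _ in range(4)]).to(device)
x = torch.randn(50, 32, 8, device=device)  # (seq_len, batch, features)

streams = [torch.cuda.Stream() for _ in rnns]
outputs = [None] * len(rnns)

torch.cuda.synchronize()  # make sure x and the weights are ready before the side streams read them
for i, (rnn, stream) in enumerate(zip(rnns, streams)):
    with torch.cuda.stream(stream):
        outputs[i], _ = rnn(x)

# Let the default stream wait for all side streams before consuming the results.
for stream in streams:
    torch.cuda.current_stream().wait_stream(stream)
stacked = torch.stack(outputs, dim=-1)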