Is this way to do parallel nn stages an antipattern?

Hello! I have some code, and it works just fine. But it's clunky-looking, and I suspect there's a cleaner and more efficient way to write it.

What I'm trying to do is run a given nn.Sequential() pipeline on (say) an image, with (for example) 10 copies of that pipeline, each with its own weights. Think of each copy as looking for a different feature in the input image.

Once the parallel copies of the pipeline have run, the results from the 10 copies are merged back together. I'm doing that like this:

import torch
import torch.nn as nn


class MyParallelNetworkStage(nn.Module):
    def __init__(self, num_parallel_copies):
        super().__init__()

        # One independent copy of the pipeline per feature, each with its own weights
        self.parallel_pipelines = nn.ModuleList([nn.Sequential(
            nn.Linear(1000, 500),
            nn.ReLU(),
            nn.Linear(500, 5),
            nn.ReLU(),
        ) for _ in range(num_parallel_copies)])

        # Merge the concatenated per-pipeline outputs down to a single value
        self.merge_results_together = nn.Sequential(
            nn.Linear(5 * num_parallel_copies, 1),
        )

    def forward(self, x):
        parallel_results = None

        # Run each parallel copy on the same input
        for each_parallel_instance in self.parallel_pipelines:
            this_result = each_parallel_instance(x)

            if parallel_results is None:
                parallel_results = this_result
            else:
                # Concatenate along dim 1 (the feature dimension, after the batch dimension)
                parallel_results = torch.cat((parallel_results, this_result), 1)

        single_result = self.merge_results_together(parallel_results)

        return single_result
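
For concreteness, I call it like this (the batch size and the flattened 1000-dimensional input are just stand-ins for my real image features):

# Illustrative usage only -- the sizes match the example layers above
stage = MyParallelNetworkStage(num_parallel_copies=10)
x = torch.randn(32, 1000)   # (batch, features)
out = stage(x)              # -> shape (32, 1): ten 5-dim results merged into one value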

But this has a 'bad code smell' to it, especially the bit where I torch.cat() each parallel pipeline's results together one at a time. Is there a different way I could write this more cleanly and efficiently? I've seen people asking for an nn.Parallel stage as part of PyTorch, but those requests have been dismissed as unnecessary, which makes me think there's a good way to write these parallel sections with the existing torch API that I'm missing.
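
One variation I've considered is collecting the per-pipeline outputs in a list and calling torch.cat() just once at the end, though I'm not sure whether that's meaningfully better or just cosmetic. A minimal sketch of what I mean, as a drop-in forward() for the class above:

    def forward(self, x):
        # Run every copy, then concatenate once along the feature dimension
        parallel_results = [pipe(x) for pipe in self.parallel_pipelines]
        return self.merge_results_together(torch.cat(parallel_results, dim=1))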

I stripped away the image-specific parts of my earlier code to make a generic nn_Parallel stage. It seems to work, but it still has that torch.cat() in it, so I don't know whether I'm losing efficiency because of it. What do you think?

import copy

import torch
import torch.nn as nn


class nn_Parallel(nn.Module):
    def __init__(self, parallel_pipe, num_copies, merge_results_pipe = None):
        super().__init__()

        # Deep-copy the template pipeline so every copy gets its own weights --
        # putting the same module instance into the ModuleList N times would make
        # all the "copies" share a single set of parameters
        self.parallel_pipes = nn.ModuleList(
            [copy.deepcopy(parallel_pipe) for _ in range(num_copies)])
        self.merge_results_pipe = merge_results_pipe

    def forward(self, x):
        parallel_results = None

        # Get the individual results from each parallel copy of the pipeline in turn
        for each_parallel_instance in self.parallel_pipes:
            this_result = each_parallel_instance(x)

            if parallel_results is None:
                parallel_results = this_result
            else:
                # Concatenate the parallel results into a single tensor
                parallel_results = torch.cat((parallel_results, this_result), 1) # dim 1: feature dim, after the batch dim

        # Either combine these parallel results back into a single result if a
        # merging stage has been provided, or if not then just return the
        # concatenated set of parallel results directly
        if self.merge_results_pipe is not None:
            single_result = self.merge_results_pipe(parallel_results)
            return single_result
        else:
            return parallel_results
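
For reference, here's how I'd wire this up to reproduce the earlier stage (the layer sizes below are just the ones from my first example, nothing special about them):

# Build one template pipeline; nn_Parallel deep-copies it 10 times internally
per_feature_pipe = nn.Sequential(
    nn.Linear(1000, 500),
    nn.ReLU(),
    nn.Linear(500, 5),
    nn.ReLU(),
)
merge_pipe = nn.Sequential(nn.Linear(5 * 10, 1))

stage = nn_Parallel(per_feature_pipe, num_copies=10, merge_results_pipe=merge_pipe)
out = stage(torch.randn(32, 1000))   # -> shape (32, 1)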