Hello! I have some code, and it works just fine. But it’s clunky-looking, and I suspect there’s a cleaner and more efficient way to write it.
What I’m trying to do is run a particular nn.Sequential() pipeline on an input (say, an image) and have several copies of that pipeline (for example, 10), each with its own weights. Think of each copy as looking for a different feature in the input image.
Once the parallel sections of the pipeline have run, the results from the 10 copies are merged together again. I’m doing that like this:
```python
import torch
import torch.nn as nn


class MyParallelNetworkStage(nn.Module):
    def __init__(self, num_parallel_copies):
        super().__init__()
        # num_parallel_copies independent pipelines, each with its own weights
        self.parallel_pipelines = nn.ModuleList([nn.Sequential(
            nn.Linear(1000, 500),
            nn.ReLU(),
            nn.Linear(500, 5),
            nn.ReLU(),
        ) for i in range(num_parallel_copies)])
        # Merges the concatenated outputs (5 features per copy) into one value
        self.merge_results_together = nn.Sequential(
            nn.Linear(5 * num_parallel_copies, 1),
        )

    def forward(self, x):
        parallel_results = None
        for each_parallel_instance in self.parallel_pipelines:
            this_result = each_parallel_instance(x)
            if parallel_results is None:
                parallel_results = this_result
            else:
                # Append this copy's output along the feature dimension
                parallel_results = torch.cat((parallel_results, this_result), 1)
        single_result = self.merge_results_together(parallel_results)
        return single_result
```
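For comparison, the same module can be written with a single concatenation. This is a minimal sketch, assuming the same layer sizes as above: it collects every copy's output in a list and calls torch.cat once, which removes the `None` check and avoids re-copying the growing tensor on every loop iteration. The computation itself is unchanged.

```python
import torch
import torch.nn as nn


class MyParallelNetworkStage(nn.Module):
    def __init__(self, num_parallel_copies):
        super().__init__()
        # Independent copies of the same pipeline, each with its own weights
        self.parallel_pipelines = nn.ModuleList([
            nn.Sequential(
                nn.Linear(1000, 500),
                nn.ReLU(),
                nn.Linear(500, 5),
                nn.ReLU(),
            )
            for _ in range(num_parallel_copies)
        ])
        self.merge_results_together = nn.Linear(5 * num_parallel_copies, 1)

    def forward(self, x):
        # Run every copy, then concatenate along the feature dimension in one call
        parallel_results = torch.cat(
            [pipeline(x) for pipeline in self.parallel_pipelines], dim=1
        )
        return self.merge_results_together(parallel_results)


stage = MyParallelNetworkStage(10)
out = stage(torch.randn(4, 1000))  # batch of 4 inputs -> shape (4, 1)
```

The Python loop over copies is still there; this version only tidies the concatenation.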
But this has a ‘bad code smell’ to it, especially the part where I torch.cat() each parallel pipeline’s results together. Is there a cleaner and more efficient way to write this? I’ve seen people ask for an nn.Parallel module in PyTorch, but those requests have been dismissed as unnecessary, which makes me think there’s a good way to write these parallel sections with the existing torch API that I’m missing.
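If every copy has identical layer shapes, one possible alternative is to stack the copies' weights into a single tensor with a leading "copy" dimension and evaluate all copies with one torch.einsum call per layer. This removes both the Python loop and the concatenation. It is a sketch under stated assumptions and not a built-in PyTorch module: `BatchedLinear` and `BatchedParallelStage` are hypothetical names, and the sizes (1000 → 500 → 5) are copied from the question.

```python
import math

import torch
import torch.nn as nn


class BatchedLinear(nn.Module):
    """Hypothetical helper: num_copies independent Linear layers evaluated with one einsum."""

    def __init__(self, num_copies, in_features, out_features):
        super().__init__()
        # One weight matrix and one bias vector per copy, stacked on a leading dim
        self.weight = nn.Parameter(torch.empty(num_copies, in_features, out_features))
        self.bias = nn.Parameter(torch.empty(num_copies, out_features))
        # Same range as nn.Linear's default init: U(-1/sqrt(in), 1/sqrt(in))
        bound = 1.0 / math.sqrt(in_features)
        nn.init.uniform_(self.weight, -bound, bound)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        # x: (copies, batch, in_features) -> (copies, batch, out_features)
        return torch.einsum("cbi,cio->cbo", x, self.weight) + self.bias.unsqueeze(1)


class BatchedParallelStage(nn.Module):
    """Hypothetical vectorized counterpart of MyParallelNetworkStage."""

    def __init__(self, num_parallel_copies):
        super().__init__()
        self.num_parallel_copies = num_parallel_copies
        self.layer1 = BatchedLinear(num_parallel_copies, 1000, 500)
        self.layer2 = BatchedLinear(num_parallel_copies, 500, 5)
        self.merge_results_together = nn.Linear(5 * num_parallel_copies, 1)

    def forward(self, x):
        # x: (batch, 1000). Every copy sees the same input, so expand (a view, no copy)
        h = x.unsqueeze(0).expand(self.num_parallel_copies, -1, -1)
        h = torch.relu(self.layer1(h))
        h = torch.relu(self.layer2(h))  # (copies, batch, 5)
        # Put copies next to features and flatten, matching the torch.cat(..., dim=1)
        # layout: copy 0's 5 outputs first, then copy 1's, and so on
        h = h.permute(1, 0, 2).reshape(x.shape[0], -1)  # (batch, copies * 5)
        return self.merge_results_together(h)


stage = BatchedParallelStage(10)
x = torch.randn(4, 1000)
out = stage(x)  # shape (4, 1)
```

Because merge_results_together receives features in the same order as the torch.cat version, the two modules compute the same function given matching weights. Whether the einsum is actually faster than the loop likely depends on the number of copies, batch size, and hardware, so it is worth benchmarking both.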