How to create a stacked network?

I am a bit puzzled about how to create a model in which one block is repeated multiple times, with the copies running in parallel to each other. Each copy receives a different input.

Task:
I want to create a classification model that looks like the image. There is a Zone block that is repeated multiple times. All these zone blocks have exactly the same parameters; they are just copies. But they have different inputs, mainly the “y_x” input each block receives. The outputs of all of them go to a softmax layer, which gives the probability of choosing each block.
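The last step (turning the per-zone outputs into choice probabilities) can be sketched as follows. This assumes the model returns one raw score (logit) per zone with shape [n_ind x n_zone] and that the training target is the index of the chosen zone; both are assumptions for illustration, and the sizes and targets are made up. The softmax runs over the zone dimension, while nn.CrossEntropyLoss is given the raw logits because it applies log-softmax internally.

```python
import torch
import torch.nn as nn

n_ind, n_zone = 4, 3  # illustrative sizes: 4 individuals, 3 zones

# Raw per-zone scores, e.g. the output of the stacked model: [n_ind x n_zone]
logits = torch.randn(n_ind, n_zone)

# Probability of choosing each zone: softmax over the zone dimension (dim=1)
probs = torch.softmax(logits, dim=1)

# Hypothetical targets: index of the zone each individual actually chose
target = torch.tensor([0, 2, 1, 2])

# CrossEntropyLoss expects raw logits; it applies log-softmax internally
loss = nn.CrossEntropyLoss()(logits, target)

print(probs.sum(dim=1))  # each row sums to 1 (up to float rounding)
print(loss)
```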

Current Solution:
My solution is to use nn.ModuleList, as shown below.

```python
import torch
import torch.nn as nn


class ZoneBlock(nn.Module):
    def __init__(self,
                 input_dim: int,
                 num_neurons: int,
                 num_hlayers: int):
        super().__init__()

        #### MODEL #####
        seq = []

        # Input layer
        seq += [nn.Linear(input_dim, num_neurons, dtype=torch.float32), nn.ReLU()]
        # Hidden layers (num_hlayers of them)
        for _ in range(num_hlayers):
            seq += [nn.Linear(num_neurons, num_neurons, dtype=torch.float32), nn.ReLU()]
        # Output layer: one score (logit) per zone
        seq += [nn.Linear(num_neurons, 1, dtype=torch.float32)]

        #### MODEL ####
        self.block = nn.Sequential(*seq)

    def forward(self, data_input):
        """Apply the model to the input."""
        data = self.block(data_input)
        return data
```
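One property of this block matters for the speed question: nn.Linear, and therefore an nn.Sequential of Linear and ReLU layers like the one in ZoneBlock, operates on the last dimension and treats all leading dimensions as batch dimensions. The sketch below uses a small stand-in network with the same layer layout as ZoneBlock (one hidden layer of 16 neurons; all sizes are illustrative) to show that a 3-D input of shape [n_ind x n_zone x n_attrb] goes through in a single call and matches running each zone separately.

```python
import torch
import torch.nn as nn

n_ind, n_zone, n_attrb = 8, 3, 5  # illustrative sizes

# Stand-in with the same layer layout as ZoneBlock (1 hidden layer, 16 neurons)
block = nn.Sequential(
    nn.Linear(n_attrb, 16), nn.ReLU(),
    nn.Linear(16, 16), nn.ReLU(),
    nn.Linear(16, 1),
)

X = torch.randn(n_ind, n_zone, n_attrb)

# Linear acts on the last dim only, so all zones are scored in one call
out = block(X)
print(out.shape)  # torch.Size([8, 3, 1])

# Same result as running each zone separately and stacking along dim 1
per_zone = torch.stack([block(X[:, z]) for z in range(n_zone)], dim=1)
print(torch.allclose(out, per_zone, atol=1e-6))
```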

Then I stack the ZoneBlocks like this:

```python
class Stacked(nn.Module):
    def __init__(self, n_zones: int, input_dim: int,
                 num_neurons=16, num_hlayers=1):
        super().__init__()

        #### Zone Block #####
        zone = ZoneBlock(input_dim=input_dim,
                         num_neurons=num_neurons,
                         num_hlayers=num_hlayers)

        #### MODEL ####
        # Note: this repeats the *same* ZoneBlock instance n_zones times,
        # so all zones share a single set of weights.
        self.zoneSet = nn.ModuleList([zone for z in range(n_zones)])

    def forward(self, X):
        # Shape of X: [n_ind x n_zone x n_attrb]

        # List for collecting the logits of each zone
        zone_outputs = list()

        # Swap the first two axes to get [n_zone x n_ind x n_attrb] so we can loop over zones.
        # (torch.reshape would keep the memory order and mix up the data; transpose swaps axes.)
        x_reshaped = X.transpose(0, 1)

        # Loop across all zones
        for idx, dataIn in enumerate(x_reshaped):
            zone_outputs.append(torch.flatten(self.zoneSet[idx](dataIn)))

        ## Return
        # Output shape: [n_ind x n_zone]
        output = torch.stack(zone_outputs).transpose(0, 1)
        return output
```
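The axis-reordering step in forward needs care: torch.reshape only reinterprets the elements in their existing memory order under a new shape; it does not swap axes. So torch.reshape(X, (X.shape[1], X.shape[0], X.shape[2])) produces the right shape but mixes individuals and zones, whereas X.transpose(0, 1) (or X.permute(1, 0, 2)) actually reorders [n_ind x n_zone x n_attrb] into [n_zone x n_ind x n_attrb]. A small check with illustrative sizes:

```python
import torch

n_ind, n_zone, n_attrb = 2, 3, 4  # illustrative sizes
X = torch.arange(n_ind * n_zone * n_attrb).reshape(n_ind, n_zone, n_attrb)

swapped = X.transpose(0, 1)                            # true axis swap
reshaped = torch.reshape(X, (n_zone, n_ind, n_attrb))  # same shape, wrong layout

# "Zone 0" should hold zone 0 of every individual, i.e. X[:, 0]
print(torch.equal(swapped[0], X[:, 0]))   # True
print(torch.equal(reshaped[0], X[:, 0]))  # False
```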

Issue:
Is there a better way to do this? I suspect the for loop in forward is what makes training so slow.

Thanks!
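If the zones really are meant to share one set of weights, the loop can be dropped entirely. In the code above, [zone for z in range(n_zones)] places the same ZoneBlock instance in the ModuleList n_zones times, so there is only one set of parameters anyway. Because the Linear layers act on the last dimension, that one network can score every zone of every individual in a single call. A minimal sketch under that shared-weights assumption; the class name SharedStacked is hypothetical, and the zone network is built inline with ZoneBlock's layer layout so the snippet runs on its own.

```python
import torch
import torch.nn as nn


class SharedStacked(nn.Module):
    """Hypothetical loop-free variant: one zone network shared by all zones."""

    def __init__(self, input_dim: int, num_neurons: int = 16, num_hlayers: int = 1):
        super().__init__()
        layers = [nn.Linear(input_dim, num_neurons), nn.ReLU()]
        for _ in range(num_hlayers):
            layers += [nn.Linear(num_neurons, num_neurons), nn.ReLU()]
        layers += [nn.Linear(num_neurons, 1)]
        self.zone = nn.Sequential(*layers)

    def forward(self, X):
        # X: [n_ind x n_zone x n_attrb] -> [n_ind x n_zone x 1]
        # squeeze the trailing size-1 dim to get logits [n_ind x n_zone]
        return self.zone(X).squeeze(-1)


model = SharedStacked(input_dim=5)
X = torch.randn(8, 3, 5)  # 8 individuals, 3 zones, 5 attributes
logits = model(X)
print(logits.shape)  # torch.Size([8, 3])
```

Since nothing in this model depends on n_zones, the same weights also work for datasets with a different number of zones, as long as each batch tensor is rectangular.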

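nn.ModuleList is a reasonable container for stacking identical blocks, but it does not remove the Python loop by itself. To improve training speed, compute all zones in one batched call instead of looping over them in the forward pass. If the zones share weights, apply the single block to the whole [n_ind x n_zone x n_attrb] tensor, since nn.Linear operates on the last dimension. If each zone needs its own weights, keep them as stacked parameter tensors so that every zone is evaluated in one batched matrix multiply.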
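If instead each zone should have its own weights, the ModuleList would need a fresh ZoneBlock per zone (for example [ZoneBlock(...) for _ in range(n_zones)]), and looping over separate modules would bring back the slow loop. One way to avoid that, sketched here as a swapped-in technique rather than anything nn.ModuleList provides, is to store all zones' weights in one parameter tensor of shape [n_zone x in_dim x out_dim] and use torch.einsum to apply every zone's layer at once. The names ZoneLinear and IndependentStacked are hypothetical, and the sizes in the usage lines are illustrative.

```python
import torch
import torch.nn as nn


class ZoneLinear(nn.Module):
    """Hypothetical batched linear layer holding a separate weight matrix per zone."""

    def __init__(self, n_zones: int, in_dim: int, out_dim: int):
        super().__init__()
        bound = in_dim ** -0.5  # same uniform range nn.Linear uses by default
        self.weight = nn.Parameter(torch.empty(n_zones, in_dim, out_dim).uniform_(-bound, bound))
        self.bias = nn.Parameter(torch.empty(n_zones, out_dim).uniform_(-bound, bound))

    def forward(self, x):
        # x: [n_ind x n_zone x in_dim] -> [n_ind x n_zone x out_dim]
        # For every zone z at once: x[:, z] @ weight[z] + bias[z]
        return torch.einsum("bzi,zio->bzo", x, self.weight) + self.bias


class IndependentStacked(nn.Module):
    """Hypothetical loop-free model in which every zone has its own weights."""

    def __init__(self, n_zones: int, input_dim: int, num_neurons: int = 16, num_hlayers: int = 1):
        super().__init__()
        layers = [ZoneLinear(n_zones, input_dim, num_neurons), nn.ReLU()]
        for _ in range(num_hlayers):
            layers += [ZoneLinear(n_zones, num_neurons, num_neurons), nn.ReLU()]
        layers += [ZoneLinear(n_zones, num_neurons, 1)]
        self.net = nn.Sequential(*layers)

    def forward(self, X):
        # X: [n_ind x n_zone x n_attrb] -> logits: [n_ind x n_zone]
        return self.net(X).squeeze(-1)


model = IndependentStacked(n_zones=3, input_dim=5)
X = torch.randn(8, 3, 5)
logits = model(X)
print(logits.shape)  # torch.Size([8, 3])

# Check one layer against an explicit per-zone computation for zone 1
lin = ZoneLinear(3, 5, 4)
manual = X[:, 1] @ lin.weight[1] + lin.bias[1]
print(torch.allclose(lin(X)[:, 1], manual, atol=1e-6))
```

The trade-off is that each layer runs as one einsum over all zones instead of n_zones small Linear calls, but the number of zones is fixed when the model is built.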