I am a bit puzzled about how to create a model in which one block is repeated multiple times, with the copies running in parallel to each other. Each copy receives a different input.
Task:
I want to create a classification model that looks like the one in the image. There is a Zone block that is repeated multiple times. All of these zone blocks have exactly the same parameters (they are just copies), but they receive different inputs, mainly the “y_x” input to each block. The outputs of all the blocks go to a softmax layer, which gives the probability of choosing each block.
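To make the last step concrete, here is a minimal sketch of the softmax over zones. The sizes, the random logits, and the `chosen` labels are placeholders for illustration only, not part of the actual model.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
n_ind, n_zone = 4, 3

# Stand-in for the per-zone scores: one logit per individual per zone.
logits = torch.randn(n_ind, n_zone)

# Softmax across the zone dimension gives each individual a
# probability distribution over the zones.
probs = torch.softmax(logits, dim=1)

# For training, nn.CrossEntropyLoss takes the raw logits (it applies
# log-softmax internally), so an explicit softmax is only needed
# when the probabilities themselves are wanted.
chosen = torch.randint(0, n_zone, (n_ind,))  # hypothetical chosen-zone labels
loss = torch.nn.CrossEntropyLoss()(logits, chosen)
```

Each row of `probs` sums to 1, and `loss` is a scalar that can be backpropagated as usual.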
Current Solution:
My solution is to use the ModuleList as shown below.
import torch
from torch import nn


class ZoneBlock(nn.Module):
    def __init__(self,
                 input_dim: int,
                 num_neurons: int,
                 num_hlayers: int):
        super().__init__()
        #### MODEL #####
        seq = []
        # Input layer
        seq += [nn.Linear(input_dim, num_neurons, dtype=torch.float32), nn.ReLU()]
        # Hidden layers
        for _ in range(num_hlayers):
            seq += [nn.Linear(num_neurons, num_neurons, dtype=torch.float32), nn.ReLU()]
        # Output layer: a single logit per input row
        seq += [nn.Linear(num_neurons, 1, dtype=torch.float32)]
        #### MODEL ####
        self.block = nn.Sequential(*seq)

    def forward(self, data_input):
        """Apply the model to the input"""
        data = self.block(data_input)
        return data
Then I stack the ZoneBlock like this.
class Stacked(nn.Module):
    def __init__(self, n_zones: int, input_dim: int,
                 num_neurons=16, num_hlayers=1):
        super().__init__()
        #### Zone Block #####
        zone = ZoneBlock(input_dim=input_dim,
                         num_neurons=num_neurons,
                         num_hlayers=num_hlayers)
        #### MODEL ####
        # The same ZoneBlock instance is repeated, so all zones share one set of weights
        self.zoneSet = nn.ModuleList([zone for _ in range(n_zones)])

    def forward(self, X):
        # Shape of X: [n_ind x n_zone x n_attrb]
        # List for collecting the logits for each zone
        zone_outputs = list()
        # Swap the first two axes to [n_zone x n_ind x n_attrb] to loop easily.
        # transpose is used because torch.reshape would not swap the axes;
        # it would mix rows across individuals.
        x_swapped = X.transpose(0, 1)
        # Loop across all zones
        for idx, dataIn in enumerate(x_swapped):
            zone_outputs.append(torch.flatten(self.zoneSet[idx](dataIn)))
        ## Return
        # Output Shape: [n_ind x n_zone]
        output = torch.stack(zone_outputs).transpose(0, 1)
        return output
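A side note on the axis swap in forward: reshaping and transposing are not interchangeable when the goal is to exchange the individual and zone axes. The small sketch below uses made-up sizes (2 individuals, 3 zones, 1 attribute) to show the difference.

```python
import torch

# Hypothetical toy input: X[i, z, 0] = 3 * i + z
X = torch.arange(6).reshape(2, 3, 1)

# reshape only reinterprets the flat memory order, so rows from
# different individuals end up mixed together.
reshaped = torch.reshape(X, (3, 2, 1))

# transpose actually swaps the individual and zone axes.
transposed = X.transpose(0, 1)

print(reshaped.squeeze(-1).tolist())    # [[0, 1], [2, 3], [4, 5]]
print(transposed.squeeze(-1).tolist())  # [[0, 3], [1, 4], [2, 5]]
```

Only the transposed version keeps each zone's column of values together, which is what the per-zone loop needs.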
Issue:
Is there a better way to do this? I suspect training is quite slow because I am using a for loop in forward.
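One possible loop-free direction, sketched under the assumption that all zones really share a single set of weights: nn.Linear acts on the last dimension and treats every leading dimension as a batch dimension, so one block could be applied to the whole [n_ind x n_zone x n_attrb] tensor at once. The class name `SharedZoneModel` and the sizes below are hypothetical, and this is a sketch rather than a tested replacement for the model above.

```python
import torch
from torch import nn


class SharedZoneModel(nn.Module):
    """Hypothetical loop-free variant: one MLP applied to every zone at once."""

    def __init__(self, input_dim: int, num_neurons: int = 16, num_hlayers: int = 1):
        super().__init__()
        layers = [nn.Linear(input_dim, num_neurons), nn.ReLU()]
        for _ in range(num_hlayers):
            layers += [nn.Linear(num_neurons, num_neurons), nn.ReLU()]
        layers += [nn.Linear(num_neurons, 1)]
        self.block = nn.Sequential(*layers)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # X: [n_ind, n_zone, n_attrb]; the linear layers only touch the last axis.
        logits = self.block(X)       # [n_ind, n_zone, 1]
        return logits.squeeze(-1)    # [n_ind, n_zone]


# Made-up sizes: 8 individuals, 3 zones, 5 attributes.
model = SharedZoneModel(input_dim=5)
X = torch.randn(8, 3, 5)
out = model(X)

# Cross-check against an explicit per-zone loop with the same weights.
looped = torch.stack(
    [model.block(X[:, z, :]).flatten() for z in range(X.shape[1])], dim=1
)
```

Here `out` and `looped` should agree up to floating-point tolerance. If the zones needed independent weights instead, this shortcut would not apply.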
Thanks!