Multiple parallel networks

Hello. I have a network with several separate encoding branches whose latent features are combined towards the end. Something like:

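import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Encode is my own encoder module
        self.encode_1 = Encode()
        self.encode_2 = Encode()
        self.encode_3 = Encode()

    def forward(self, x):
        y1 = self.encode_1(x)
        y2 = self.encode_2(x)
        y3 = self.encode_3(x)
        y_final = y1 + y2 + y3
        return y_final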

This works fine, but I don’t like the fact that I have hard-coded 3 branches. What can I do to allow for a variable number of branches?


You can store the branches in an nn.ModuleList and loop over them:

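import torch.nn as nn

class Net(nn.Module):
    def __init__(self, nb_branches=3):
        super().__init__()
        # nn.ModuleList registers each branch as a submodule,
        # so its parameters are tracked by the parent module
        self.encoders = nn.ModuleList()
        for i in range(nb_branches):
            self.encoders.append(Encode())

    def forward(self, x):
        # Sum the latent features of all branches
        y_final = 0
        for encoder in self.encoders:
            y_final += encoder(x)
        return y_final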
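To check this end to end, here is a self-contained sketch under a few assumptions: `Encode` is a hypothetical stand-in (one `nn.Linear` followed by `nn.ReLU`), and `MultiBranchNet`, `PlainListNet` and `count_params` are illustrative names, not part of the original code. It sums the branches with `torch.stack(...).sum(dim=0)` instead of the running `y_final += ...`. Both give the same result; this version just avoids starting from the Python integer 0. It also shows why `nn.ModuleList` matters: with a plain Python list the branches' parameters are not registered, so an optimizer built from `.parameters()` would never update them.

```python
import torch
import torch.nn as nn


class Encode(nn.Module):
    # Hypothetical stand-in for the real encoder: Linear + ReLU.
    def __init__(self, in_features=8, latent_features=4):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_features, latent_features), nn.ReLU())

    def forward(self, x):
        return self.net(x)


class MultiBranchNet(nn.Module):
    def __init__(self, nb_branches, in_features=8, latent_features=4):
        super().__init__()
        # nn.ModuleList registers every branch as a submodule, so its
        # parameters appear in .parameters() and move with .to(device).
        self.encoders = nn.ModuleList(
            Encode(in_features, latent_features) for _ in range(nb_branches)
        )

    def forward(self, x):
        # Run every branch on the same input, then sum the latent features
        # element-wise (equivalent to y1 + y2 + ... + yN).
        return torch.stack([encoder(x) for encoder in self.encoders]).sum(dim=0)


class PlainListNet(nn.Module):
    # Anti-pattern for comparison: a plain Python list hides the branches
    # from nn.Module, so their parameters are NOT registered.
    def __init__(self, nb_branches):
        super().__init__()
        self.encoders = [Encode() for _ in range(nb_branches)]


def count_params(module):
    # Total number of registered scalar parameters.
    return sum(p.numel() for p in module.parameters())


x = torch.randn(2, 8)
for n in (1, 3, 5):
    net = MultiBranchNet(n)
    print(n, tuple(net(x).shape), count_params(net))

print("plain list:", count_params(PlainListNet(3)))
```

Changing `nb_branches` is then the only edit needed to add or remove branches. With the sizes assumed here, each branch contributes 36 parameters (8 × 4 weights plus 4 biases), while the plain-list version reports 0.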