All the predictors have the same structure, so I did something like this:
self.predictors = nn.ModuleDict()
for i in range(n):
    self.predictors["predictor_" + str(i)] = nn.Sequential(
        << STRUCTURE OF THE PREDICTORS >>
    )
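Because a new `nn.Sequential` is created on every loop iteration, the predictors share only their structure, not their weights, and `nn.ModuleDict` registers each one as a submodule so its parameters are visible to the optimizer. A minimal sketch, assuming a hypothetical predictor structure (Linear, ReLU, Linear) and hypothetical sizes, since the real structure is elided in the post:

```python
import torch
from torch import nn

# Hypothetical sizes; the real predictor structure is elided in the post.
n, in_features, hidden, out_features = 3, 16, 32, 4

predictors = nn.ModuleDict()
for i in range(n):
    # A new nn.Sequential per iteration, so each predictor gets its own weights.
    predictors["predictor_" + str(i)] = nn.Sequential(
        nn.Linear(in_features, hidden),
        nn.ReLU(),
        nn.Linear(hidden, out_features),
    )

# ModuleDict registers every predictor, so all parameters are collected.
num_params = sum(p.numel() for p in predictors.parameters())
per_predictor = (in_features * hidden + hidden) + (hidden * out_features + out_features)
print(list(predictors.keys()))          # ['predictor_0', 'predictor_1', 'predictor_2']
print(num_params == n * per_predictor)  # True

# The first layers of two predictors are distinct tensors, not shared weights.
w0 = predictors["predictor_0"][0].weight
w1 = predictors["predictor_1"][0].weight
print(w0 is w1)  # False
```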
And my forward method would look something like this:
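def forward(self, x):
    # Encoder
    x = self.encoder(x)
    # Predictors: run each one on the encoded features and concatenate the results
    # (dim=1 assumes each predictor returns a tensor of shape [batch, features])
    output = torch.cat(
        tuple(predictor(x) for predictor in self.predictors.values()), dim=1
    )
    return output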
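For reference, here is a self-contained sketch of the whole model that can be run end to end. The encoder, the predictor structure, and all sizes are hypothetical stand-ins for the parts not shown in the post, and concatenating along `dim=1` assumes each predictor outputs a `[batch, features]` tensor:

```python
import torch
from torch import nn


class MultiPredictorModel(nn.Module):
    """Shared encoder followed by n predictors with identical structure (hypothetical sizes)."""

    def __init__(self, n=3, in_features=8, hidden=16, out_features=2):
        super().__init__()
        # Hypothetical encoder; the post does not show its structure.
        self.encoder = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())
        self.predictors = nn.ModuleDict()
        for i in range(n):
            # Hypothetical predictor structure standing in for the elided one.
            self.predictors["predictor_" + str(i)] = nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Linear(hidden, out_features),
            )

    def forward(self, x):
        x = self.encoder(x)
        # Each predictor returns [batch, out_features]; concatenating on dim=1
        # places them side by side, giving [batch, n * out_features].
        return torch.cat([p(x) for p in self.predictors.values()], dim=1)


x = torch.randn(5, 8)
model = MultiPredictorModel()
out = model(x)
print(out.shape)  # torch.Size([5, 6])
```

Iterating over `self.predictors.values()` follows the insertion order of the `ModuleDict`, so in this sketch the slice `out[:, 2:4]` is the output of `predictor_1`.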
Yes, the code looks generally alright.
One minor suggestion: if the first module in each predictor might apply an in-place operation (e.g. nn.ReLU(inplace=True)), you would need to pass a clone of the input, x.clone(), to each predictor. Otherwise x will be modified in place, and the following predictors will receive the already manipulated input.
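A small sketch of the effect described above, assuming hypothetical predictors whose first module is `nn.ReLU(inplace=True)`. It checks whether the shared input (standing in for the encoder output) is still intact after all predictors have run, with and without `x.clone()`:

```python
import torch
from torch import nn

# Hypothetical predictors whose first module is an in-place ReLU.
predictors = nn.ModuleDict({
    "predictor_" + str(i): nn.Sequential(nn.ReLU(inplace=True), nn.Linear(4, 2))
    for i in range(2)
})


def run_predictors(x, clone_input):
    # With clone_input=True each predictor receives its own copy of x.
    return [p(x.clone() if clone_input else x) for p in predictors.values()]


original = torch.tensor([[-1.0, 2.0, -3.0, 4.0]])  # stands in for the encoder output

x = original.clone()
run_predictors(x, clone_input=True)
print(torch.equal(x, original))  # True: the shared input is untouched

x = original.clone()
run_predictors(x, clone_input=False)
print(torch.equal(x, torch.tensor([[0.0, 2.0, 0.0, 4.0]])))  # True: the first ReLU overwrote x
```

In the second run, every predictor after the first one receives the already rectified tensor instead of the original encoder output.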