How to create independent nn.Module blocks pythonically

I'm simplifying things here, but I want to build the following structure:

x -> [Encoder] -> x_encoded
x_encoded -> [predictor 1] -> y_1
x_encoded -> [predictor 2] -> y_2
...
x_encoded -> [predictor n] -> y_n
output = concat(y_1, …, y_n)

All the predictors have the same structure, so I did something like this:

self.predictors = nn.ModuleDict()
for i in range(n):
    self.predictors["predictor_" + str(i)] = nn.Sequential(
        << STRUCTURE OF THE PREDICTORS >>
    )

And my forward will be something like:

def forward(self, x):
    # Encoder
    x = self.encoder(x)
    # Predictors
    output = torch.cat(tuple(
        self.predictors["predictor_" + str(i)](x) for i in range(n)
    ))
    return output
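
For context, here is a minimal self-contained version of the whole module. The class name, the encoder/predictor bodies (a Linear + ReLU placeholder) and the dim=1 concatenation are just assumptions to make it runnable, not my real architecture:

import torch
import torch.nn as nn

class MultiHeadModel(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, n):
        super().__init__()
        self.n = n
        # placeholder encoder
        self.encoder = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())
        # one independent predictor per output
        self.predictors = nn.ModuleDict()
        for i in range(n):
            self.predictors["predictor_" + str(i)] = nn.Sequential(
                nn.Linear(hidden_dim, out_dim)
            )

    def forward(self, x):
        x = self.encoder(x)
        return torch.cat(
            tuple(self.predictors["predictor_" + str(i)](x) for i in range(self.n)),
            dim=1,  # concatenate the heads along the feature dimension
        )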

Is this the right way to do it?

Yes, the code generally looks alright.
One minor suggestion: if the first module in the predictors applies an in-place operation, you would need to clone the input via x.clone() before passing it to each predictor. Otherwise the shared input would be changed in place and the following predictors would receive the already manipulated tensor.
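
As a rough sketch (assuming e.g. nn.ReLU(inplace=True) could be the first layer of a predictor, and that the number of heads is stored as self.n and concatenated along dim=1):

def forward(self, x):
    x = self.encoder(x)
    # clone per predictor so an in-place op inside one head
    # cannot modify the tensor seen by the other heads
    outs = tuple(
        self.predictors["predictor_" + str(i)](x.clone()) for i in range(self.n)
    )
    return torch.cat(outs, dim=1)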


Thanks for the suggestion! A friend of mine also coded a MultiLinear class to deal with this kind of problem:

import torch

class MultiLinear(torch.nn.Module):
    def __init__(self, input_size, output_size, nb_estimators, activation=torch.relu):
        super().__init__()
        # one weight matrix per estimator: (nb_estimators, input_size, output_size)
        self.weights = torch.nn.Parameter(torch.randn(nb_estimators, input_size, output_size))
        self.activation = activation

    def forward(self, x_batch):
        if x_batch.ndim == 2:
            # x_batch: (batch_size, input_size), shared by all estimators
            z = torch.einsum("ij,mjl->ilm", x_batch, self.weights)
        elif x_batch.ndim == 3:
            # x_batch: (batch_size, input_size, nb_estimators), one input per estimator
            z = torch.einsum("ijm,mjl->ilm", x_batch, self.weights)
        else:
            raise ValueError("input dims must be 2 or 3")
        output = self.activation(z)
        # output: (batch_size, output_size, nb_estimators)
        return output
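
For what it's worth, a quick usage sketch of the class above (the sizes are made up): a 2D batch is shared by all estimators, a 3D batch gives each estimator its own input, and in both cases the outputs are stacked along the last dimension.

# toy example: 8 samples, 16 features, 4 parallel estimators with output size 3
ml = MultiLinear(input_size=16, output_size=3, nb_estimators=4)

x = torch.randn(8, 16)
y = ml(x)
print(y.shape)  # torch.Size([8, 3, 4]) -> (batch_size, output_size, nb_estimators)

# the 3D path feeds a separate input to each estimator
x_per_estimator = torch.randn(8, 16, 4)
y2 = ml(x_per_estimator)
print(y2.shape)  # torch.Size([8, 3, 4])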