I have a use case where we need independent MLPs acting on features in parallel. This code works, but it seems there may be a more efficient way to do this in PyTorch.
Note that in this case the input has dimensions batch x feature x latent.
In normal use, a single MLP would be specified and applied to all the features.
Another normal use is to flatten the feature x latent dimensions and run the result through one MLP, which mixes information across the features.
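For comparison, the two standard setups described above can be sketched like this. It is a minimal sketch assuming batch=8, feature=4, latent=32 and a shortened flattened MLP; the variable names are illustrative only.

```python
import torch

batch, feature, latent = 8, 4, 32
x = torch.randn(batch, feature, latent)

# Setup 1: one shared MLP. nn.Linear acts on the last dimension,
# so the same weights are applied to every feature.
shared_mlp = torch.nn.Sequential(
    torch.nn.Linear(latent, latent), torch.nn.ReLU(),
    torch.nn.Linear(latent, latent), torch.nn.ReLU(),
)
y_shared = shared_mlp(x)  # batch x feature x latent

# Setup 2: flatten feature x latent into one vector per sample and use
# one MLP, which mixes information across the features.
flat_mlp = torch.nn.Sequential(
    torch.nn.Linear(feature * latent, feature * latent), torch.nn.ReLU(),
)
y_flat = flat_mlp(x.flatten(start_dim=1)).view(batch, feature, latent)
```

Neither setup gives each feature its own independent weights, which is what the use case below requires.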
But my use case is unusual: it calls for a separate MLP to be trained for each of the 4 features, with all of them running in parallel. The following code works and does what it should. It uses `Sequential` to define the independent MLPs, but the forward pass loops over the features, which feels inefficient to me.
Is there a better and more efficient way of doing this in PyTorch? I've had a dig around and can't find any helpful pre-built layers or discussions on this. Much appreciated. Thanks. Matt
```python
import torch


class IndependentMLPs(torch.nn.Module):
    """
    A set of parallel MLPs working independently on each feature,
    transforming each feature with a three-layer FF network.
    Input dimensions = batch x feature x latent.
    A separate MLP of Linear(latent x latent) layers operates on
    each of the features INDEPENDENTLY.
    """

    def __init__(self, feature_dim=4, latent_dim=32):
        super().__init__()
        self.feature_layers = torch.nn.ModuleList([
            torch.nn.Sequential(  # Three-layer FF network
                torch.nn.Linear(latent_dim, latent_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(latent_dim, latent_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(latent_dim, latent_dim),
                torch.nn.ReLU(),
            )
            for _ in range(feature_dim)
        ])

    def forward(self, x):
        """
        Input x has the dimensions batch x feature (4) x latent (32).
        """
        outputs = []
        for i, layer in enumerate(self.feature_layers):
            in_ = x[:, i, :]  # Input from feature i: batch x latent
            out = layer(in_)
            out = out.reshape(out.shape[0], 1, out.shape[1])  # batch x 1 x latent
            outputs.append(out)
        x = torch.cat(outputs, dim=1)  # Rebuild transformed x
        return x
```
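One possible loop-free approach (not from the original post) is to store each layer's weights for all features in a single tensor of shape feature x out x in, and use `torch.einsum` so that feature f is multiplied only by its own matrix. This is a minimal sketch assuming the same three Linear+ReLU layers per feature as above; `BatchedIndependentMLPs`, `num_layers`, `weights` and `biases` are hypothetical names.

```python
import torch


class BatchedIndependentMLPs(torch.nn.Module):
    """Hypothetical vectorized variant: each layer's weight tensor holds a
    separate (latent x latent) matrix per feature, so every feature still
    has its own independent MLP, but there is no Python loop over features."""

    def __init__(self, feature_dim=4, latent_dim=32, num_layers=3):
        super().__init__()
        self.weights = torch.nn.ParameterList()
        self.biases = torch.nn.ParameterList()
        # Same uniform range nn.Linear uses for its default weight/bias init
        bound = 1.0 / latent_dim ** 0.5
        for _ in range(num_layers):
            # Weight layout feature x out x in, matching nn.Linear's (out, in)
            w = torch.empty(feature_dim, latent_dim, latent_dim).uniform_(-bound, bound)
            b = torch.empty(feature_dim, latent_dim).uniform_(-bound, bound)
            self.weights.append(torch.nn.Parameter(w))
            self.biases.append(torch.nn.Parameter(b))

    def forward(self, x):
        # x: batch x feature x latent
        for w, b in zip(self.weights, self.biases):
            # For each feature f: out[:, f] = x[:, f] @ w[f].T + b[f]
            x = torch.einsum("bfi,foi->bfo", x, w) + b
            x = torch.relu(x)
        return x
```

All feature-wise matrix products for a layer run in a single call instead of one call per feature. With only 4 small features the speedup over the loop may be modest, so it is worth profiling both versions on real input sizes.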