A more efficient way to create independent parallel MLP layers in PyTorch

I have a use case where we need independent MLPs acting on features in parallel. The code below works, but I suspect there is a more efficient way to do this in PyTorch.

Note that in this case the input has dimensions batch x feature x latent.

In the usual setup, a single MLP is defined and applied to every feature, so all features share the same weights.
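For reference, here is a minimal sketch of that shared setup. `torch.nn.Linear` acts on the last dimension, so a batch x feature x latent tensor can be passed straight through and every feature sees the same weights. The sizes (batch of 8, 4 features, latent 32) and the `ReLU` between layers are illustrative assumptions, not part of the original post.

```python
import torch

torch.manual_seed(0)
batch, feature_dim, latent_dim = 8, 4, 32  # illustrative sizes (assumption)

# One MLP shared by every feature: Linear acts on the last dimension, so the
# same weights are applied to each of the feature_dim latent vectors.
shared_mlp = torch.nn.Sequential(
    torch.nn.Linear(latent_dim, latent_dim),
    torch.nn.ReLU(),  # activation choice is an assumption
    torch.nn.Linear(latent_dim, latent_dim),
)

x = torch.randn(batch, feature_dim, latent_dim)
y = shared_mlp(x)  # batch x feature x latent
print(y.shape)  # torch.Size([8, 4, 32])
```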

Another common approach is to flatten the feature x latent dimensions and pass the result through one MLP. This mixes information across the features.
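A minimal sketch of the flattening approach, under the same illustrative sizes and with an assumed `ReLU` between layers: `Flatten` turns each sample into one feature * latent vector, so the first `Linear` layer combines all features, and `Unflatten` restores the batch x feature x latent shape afterwards.

```python
import torch

torch.manual_seed(0)
batch, feature_dim, latent_dim = 8, 4, 32  # illustrative sizes (assumption)
flat_dim = feature_dim * latent_dim

# A single MLP over the flattened features: every output depends on every input feature.
mixing_mlp = torch.nn.Sequential(
    torch.nn.Flatten(start_dim=1),                     # batch x (feature * latent)
    torch.nn.Linear(flat_dim, flat_dim),
    torch.nn.ReLU(),                                   # activation choice is an assumption
    torch.nn.Linear(flat_dim, flat_dim),
    torch.nn.Unflatten(1, (feature_dim, latent_dim)),  # back to batch x feature x latent
)

x = torch.randn(batch, feature_dim, latent_dim)
y = mixing_mlp(x)
print(y.shape)  # torch.Size([8, 4, 32])
```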

My use case is different: it calls for a separate MLP, with its own weights, to be trained for each of the 4 features, with all of them running in parallel. The code below works and does what it should. It uses ModuleList and Sequential to define the independent MLPs, but the forward pass has to loop over the features, which feels inefficient to me.

Is there a better, more efficient way of doing this in PyTorch? I’ve had a dig around and can’t find any helpful pre-built layers or discussions on this. Much appreciated. Thanks, Matt
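One pre-built layer that can express this is `torch.nn.Conv1d` with `kernel_size=1` and `groups=feature_dim`: the channels are split into `feature_dim` blocks of `latent_dim` channels, and each block gets its own latent x latent weight matrix and bias, which is exactly an independent `Linear` per feature. The sketch below assumes that mapping; the class name `GroupedIndependentMLPs` is hypothetical. Like the original code it has no activations; an elementwise activation such as `ReLU` could be inserted between the convolutions without breaking the independence.

```python
import torch


class GroupedIndependentMLPs(torch.nn.Module):
    """Hypothetical alternative: one independent three-layer stack per feature,
    computed for all features at once with grouped 1x1 convolutions."""

    def __init__(self, feature_dim=4, latent_dim=32):
        super().__init__()
        self.feature_dim = feature_dim
        self.latent_dim = latent_dim
        channels = feature_dim * latent_dim
        # groups=feature_dim: channels [i*latent, (i+1)*latent) only connect to each
        # other, so block i behaves like a separate Linear(latent_dim, latent_dim).
        self.layers = torch.nn.Sequential(
            torch.nn.Conv1d(channels, channels, kernel_size=1, groups=feature_dim),
            torch.nn.Conv1d(channels, channels, kernel_size=1, groups=feature_dim),
            torch.nn.Conv1d(channels, channels, kernel_size=1, groups=feature_dim),
        )

    def forward(self, x):
        # batch x feature x latent -> batch x (feature * latent) x 1 (length-1 "sequence")
        b = x.shape[0]
        h = x.reshape(b, self.feature_dim * self.latent_dim, 1)
        h = self.layers(h)
        # back to batch x feature x latent
        return h.reshape(b, self.feature_dim, self.latent_dim)
```

Each grouped convolution replaces `feature_dim` separate `Linear` calls with one kernel launch; whether that is faster in practice depends on the sizes and hardware, so it is worth benchmarking against the loop.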

```python
import torch


class IndependentMLPs(torch.nn.Module):
    """
    A set of parallel MLPs, each working independently on one feature and
    transforming it with its own three-layer feed-forward network.
    Input dimensions = batch x feature x latent

    A separate MLP of Linear(latent x latent) layers operates on each feature INDEPENDENTLY.
    """

    def __init__(self, feature_dim=4, latent_dim=32):
        super().__init__()
        # One Sequential per feature. Without activations (e.g. torch.nn.ReLU())
        # between the layers, each stack reduces to a single affine map.
        self.feature_layers = torch.nn.ModuleList([
            torch.nn.Sequential(
                # Three-layer FF network
                torch.nn.Linear(latent_dim, latent_dim),
                torch.nn.Linear(latent_dim, latent_dim),
                torch.nn.Linear(latent_dim, latent_dim),
            )
            for _ in range(feature_dim)
        ])

    def forward(self, x):
        """Input x has the dimensions batch x feature (4) x latent (32)."""
        outputs = []
        for i, layer in enumerate(self.feature_layers):
            in_ = x[:, i, :]  # Input from feature i: batch x latent
            out = layer(in_)
            outputs.append(out.unsqueeze(1))  # batch x 1 x latent
        return torch.cat(outputs, dim=1)  # Rebuild transformed x: batch x feature x latent
```
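Another way to remove the loop is to store the weights of all features in one parameter tensor of shape feature x out x in and apply them together with `torch.einsum`. The sketch below is an assumption-based alternative, not a built-in PyTorch layer: `BatchedLinear` and `BatchedIndependentMLPs` are hypothetical names, and the initialisation copies the U(-1/sqrt(in), 1/sqrt(in)) range that `torch.nn.Linear` uses by default.

```python
import math

import torch


class BatchedLinear(torch.nn.Module):
    """Hypothetical layer: feature_dim independent Linear(in_dim, out_dim) maps
    stored in one weight tensor and applied with a single einsum."""

    def __init__(self, feature_dim, in_dim, out_dim):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(feature_dim, out_dim, in_dim))
        self.bias = torch.nn.Parameter(torch.empty(feature_dim, out_dim))
        # Same uniform range as nn.Linear's default init: U(-1/sqrt(in), 1/sqrt(in))
        bound = 1 / math.sqrt(in_dim)
        torch.nn.init.uniform_(self.weight, -bound, bound)
        torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        # x: batch x feature x in_dim -> batch x feature x out_dim
        # For each feature f: y[:, f] = x[:, f] @ weight[f].T + bias[f]
        return torch.einsum("bfi,foi->bfo", x, self.weight) + self.bias


class BatchedIndependentMLPs(torch.nn.Module):
    """Hypothetical loop-free counterpart of IndependentMLPs (three layers, no activations)."""

    def __init__(self, feature_dim=4, latent_dim=32):
        super().__init__()
        self.layers = torch.nn.Sequential(
            BatchedLinear(feature_dim, latent_dim, latent_dim),
            BatchedLinear(feature_dim, latent_dim, latent_dim),
            BatchedLinear(feature_dim, latent_dim, latent_dim),
        )

    def forward(self, x):
        return self.layers(x)
```

Each layer is one batched matrix multiply over all features instead of `feature_dim` small ones. As with the grouped convolution, any speed-up depends on the tensor sizes and device, so it is worth timing both versions on real inputs.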