Low performance when iterating over ModuleDict in forward()

I’m currently trying to find a faster way to process many (~100) categorical embeddings in the forward pass.

As the features have distinct vocabulary sizes and embedding dimensions, each requires its own embedding layer. Currently, I initialize them as nn.ModuleDict (emb_model in the snippet below), where the keys correspond to the names of the categorical features and the values correspond to the embedding layers (each with a weight tensor of shape (size_vocab, d_emb)).
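For illustration, the ModuleDict is built roughly like this (vocab_sizes and emb_dims are placeholder dicts mapping feature names to vocabulary sizes and embedding dimensions, not the actual names in my code):

# inside __init__: one nn.Embedding per categorical feature, keyed by feature name
self.emb_model = nn.ModuleDict({
    name: nn.Embedding(vocab_sizes[name], emb_dims[name])
    for name in vocab_sizes
})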

In the forward pass, I loop over the ~100 .items() of the ModuleDict and provide the corresponding features as inputs (X is a dictionary with the same keys as the ModuleDict) to each layer:

def forward(self, X):
    # one embedding lookup per categorical feature
    concat_list = []
    for key, layer in self.emb_model.items():
        x = layer(X[key])
        concat_list.append(x)

    # concatenate the per-feature embeddings along the feature dimension
    x = torch.cat(concat_list, dim=1)
...

However, this solution is about 3x slower than an equivalent model using the Keras/TF functional API, which I presume is due to the Python loop being executed in every forward pass.

Is there any better way to handle large numbers of categorical features in PyTorch?

This is kinda hacky, but as long as your GPU memory and utilization can handle it, couldn’t you make each iteration into its own multiprocessing worker? That would send concurrent forward passes to the GPU, allowing you to execute everything at once. Since this doesn’t appear to rely on the previous iteration’s output, it could be run concurrently.

Thanks for the tip, I will try that! Ideally, though, I’d prefer to solve it without relying on multiprocessing workers.

What I have tried is defining a MultiEmbedding class as:

import torch
import torch.nn as nn

class MultiEmbedding(nn.Module):
    def __init__(self, modules):
        super().__init__()
        # modules is an nn.ModuleList, so assigning it registers all embedding layers
        self.emb = modules

    def forward(self, X):
        # X: (batch, n_features) integer tensor; column i indexes into embedding layer i
        x = torch.cat([m(X[:, i].flatten()) for i, m in enumerate(self.emb)], dim=1)
        return x

which is initialized with an nn.ModuleList. In the initialization of the main module, I then convert this submodule to TorchScript:

self.emb = torch.jit.script(MultiEmbedding(modules))

and then no longer have a loop in the forward() of the main module. Instead of passing multiple integer tensors to the different embedding layers, all integer-encoded categorical variables are combined into a single tensor and passed directly to the TorchScript MultiEmbedding.
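Concretely, the call site in the main module now looks roughly like this (feature_names is a placeholder for the fixed column ordering of the categorical features, not the actual attribute name in my code):

# stack the integer codes of all categorical features into one (batch, n_features) tensor
X_cat = torch.stack([X[name] for name in self.feature_names], dim=1)

# single call into the scripted submodule, returns (batch, sum of all d_emb)
x = self.emb(X_cat)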

This yields a small performance increase, but the TF/Keras equivalent is still more than 2x faster.

The only other idea I have would be to (at some point in the future) combine vmap (over the integer indices) and NestedTensor (holding the differently sized embedding matrices) in a new nn.Module …