Vectorize nn.ModuleDict


For the past few days, I’ve been trying to figure out how I can speed up my multi-modal model (a feature extractor for reinforcement learning that takes images and robot states as input). I know that torch.func provides stack_module_state, but it only works when all the models in the list share the same architecture. Is there a way to restructure the model represented by the nn.ModuleDict so that it receives a dictionary of tensors, runs all submodules in parallel, and then collects the results? Thanks in advance for any advice!
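For comparison, here is a minimal sketch of the torch.func pattern referred to above, applied to an nn.ModuleDict whose entries all happen to share one architecture. The encoder names, layer sizes and the `vectorized_forward` helper are hypothetical. It accepts a dictionary of tensors, runs every encoder in a single vmap call, and returns a dictionary of outputs. It also shows why the approach does not carry over to a mix of image and state encoders: both the weights and the inputs must have identical shapes to be stacked.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state

# Hypothetical example: three encoders that share one architecture.
encoders = nn.ModuleDict({
    name: nn.Sequential(nn.Linear(6, 16), nn.ReLU(), nn.Linear(16, 8))
    for name in ("joint_pos", "joint_vel", "ee_pose")
})


def vectorized_forward(encoders: nn.ModuleDict, inputs: dict) -> dict:
    keys = list(encoders.keys())
    # Stack the parameters (and buffers) of all encoders along a new leading dim.
    params, buffers = stack_module_state([encoders[k] for k in keys])
    # Weightless "skeleton" copy on the meta device; functional_call
    # substitutes the real (per-encoder) weights at call time.
    base = copy.deepcopy(encoders[keys[0]]).to("meta")

    def run_one(p, b, x):
        return functional_call(base, (p, b), (x,))

    # Inputs must share one shape to be stacked: (num_encoders, batch, features).
    x = torch.stack([inputs[k] for k in keys])
    out = torch.vmap(run_one)(params, buffers, x)
    # Split the batched result back into a dictionary keyed like the ModuleDict.
    return {k: out[i] for i, k in enumerate(keys)}
```

Note that stack_module_state returns new leaf tensors that are not connected to the original modules' parameters. For training, you would stack once and hand the stacked `params` to the optimizer instead of re-stacking on every forward call as this sketch does.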