Forward tensors to modules in a ModuleList efficiently?

Given an nn.ModuleList that contains K modules and a pair of tensors, where the first tensor holds the data and the second holds, for each sample, the index of the module in the nn.ModuleList that should process it, how can I forward each sample to its associated module efficiently?

The only way I've been able to do it is to group the samples by module index and forward each group, as a batch, through its module in turn. Is there a faster way?

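```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_size = 40
output_size = 40
number_of_module = 8  # K (example value)
batch_size = 256      # example value

module_list = nn.ModuleList(
    [nn.Linear(input_size, output_size) for _ in range(number_of_module)]
).to(device)
# Index of the module in the ModuleList that should process each sample
t_idx = torch.randint(0, number_of_module, (batch_size,), device=device)
data = torch.rand(batch_size, input_size, device=device)

# Output buffer; every row is overwritten below, so no random init is needed
next_state = torch.empty(t_idx.size(0), output_size, device=device)
for ixx in torch.unique(t_idx).tolist():
    ixs = t_idx == ixx  # boolean mask of the samples routed to module ixx
    next_state[ixs] = module_list[ixx](data[ixs])
```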
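One possible way to remove the Python loop, assuming every module is an nn.Linear with the same input and output sizes (as in the code above), is to stack all the weights and biases into single tensors, gather one weight matrix and bias per sample with `t_idx`, and do a single batched matrix-vector product with `torch.einsum`. This is a sketch, not a general solution: `forward_stacked` and `forward_loop` are hypothetical helper names, and the module count and batch size are example values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

input_size = 40
output_size = 40
number_of_module = 8  # K (example value)
batch_size = 256      # example value

module_list = nn.ModuleList(
    [nn.Linear(input_size, output_size) for _ in range(number_of_module)]
)
t_idx = torch.randint(0, number_of_module, (batch_size,))
data = torch.rand(batch_size, input_size)


def forward_stacked(module_list, data, t_idx):
    """Hypothetical helper: apply module_list[t_idx[i]] to data[i] for every i
    without looping over modules. Assumes all modules are nn.Linear of equal shape."""
    # (K, out, in) and (K, out); torch.stack keeps gradients flowing to each module
    weight = torch.stack([m.weight for m in module_list])
    bias = torch.stack([m.bias for m in module_list])
    # Gather one weight matrix and bias per sample: (B, out, in) and (B, out)
    w = weight[t_idx]
    b = bias[t_idx]
    # y[n] = w[n] @ data[n] + b[n], i.e. what nn.Linear computes per sample
    return torch.einsum("noi,ni->no", w, data) + b


def forward_loop(module_list, data, t_idx):
    """Hypothetical helper: the grouped loop from the question, for comparison."""
    out = data.new_empty(data.size(0), module_list[0].out_features)
    for k in torch.unique(t_idx).tolist():
        mask = t_idx == k
        out[mask] = module_list[k](data[mask])
    return out


fast = forward_stacked(module_list, data, t_idx)
slow = forward_loop(module_list, data, t_idx)
print(torch.allclose(fast, slow, atol=1e-5))
```

The two results should agree up to floating-point rounding. The trade-off is memory: `weight[t_idx]` materializes a `(batch_size, output_size, input_size)` tensor, which is small for 40x40 layers but can become expensive for large layers or batches, in which case the grouped loop may still be the better choice.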