Is there a way to use a list of indices to simultaneously access the modules of `nn.ModuleList` in Python?

I am working with a PyTorch `nn.ModuleList`, defined as follows:

```python
decision_modules = nn.ModuleList([nn.Linear(768, 768) for i in range(10)])
```
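As a quick sketch of what `nn.ModuleList` indexing does support: a plain integer returns a single submodule, and a slice returns a new `ModuleList` holding references to the selected submodules. There is no built-in case for a tensor or a list of indices.

```python
import torch.nn as nn

decision_modules = nn.ModuleList([nn.Linear(768, 768) for i in range(10)])

# An integer index returns one submodule.
first = decision_modules[0]
print(type(first).__name__)                 # Linear

# A slice returns a new ModuleList that shares the same submodule objects.
subset = decision_modules[2:6]
print(type(subset).__name__, len(subset))   # ModuleList 4
```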

The input is `x = torch.rand(32, 768)`, where 32 is the batch size and 768 is the feature dimension.

For each data point in a minibatch of 32, I want to select 4 modules from `decision_modules`. The 4 modules are selected using an index matrix, as explained below.

I have an index matrix `ind`, created with `torch.randint(0, 10, (4, 4))`, whose entries are indices into `decision_modules`.
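To pin down the intended computation, here is a minimal loop-based reference. It assumes (this is not stated in the post) that `ind` holds one row of 4 indices per data point, i.e. shape `(32, 4)` rather than `(4, 4)`; the helper name `reference_forward` is hypothetical. It is slow by design and is mainly useful for checking a loop-free version against.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

decision_modules = nn.ModuleList([nn.Linear(768, 768) for i in range(10)])
x = torch.rand(32, 768)                  # (batch, features)
# Assumption: one row of 4 module indices per data point, shape (32, 4).
ind = torch.randint(0, 10, (32, 4))

def reference_forward(modules, x, ind):
    """Hypothetical helper: out[b, k] = modules[ind[b, k]](x[b])."""
    rows = []
    for b in range(ind.shape[0]):
        # Apply each of this data point's selected modules to its feature vector.
        rows.append(torch.stack([modules[int(i)](x[b]) for i in ind[b]]))
    return torch.stack(rows)             # (batch, 4, out_features)

out = reference_forward(decision_modules, x, ind)
print(out.shape)                         # torch.Size([32, 4, 768])
```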

I want a solution that avoids loops, since loops slow down execution significantly.
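One loop-free sketch swaps in a different technique rather than indexing the `ModuleList` directly: stack the `nn.Linear` weights and biases into single tensors, run all 10 modules on every input with one `torch.einsum`, and then use advanced indexing to keep only the 4 selected outputs per data point. It assumes every module is an `nn.Linear(768, 768)` with a bias and that `ind` has shape `(32, 4)`; `selected_forward` is a hypothetical helper name.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

decision_modules = nn.ModuleList([nn.Linear(768, 768) for i in range(10)])
x = torch.rand(32, 768)
# Assumption: one row of 4 module indices per data point, shape (32, 4).
ind = torch.randint(0, 10, (32, 4))

def selected_forward(modules, x, ind):
    # Stack parameters: weight (M, out, in), bias (M, out).
    # torch.stack is differentiable, so gradients still reach each nn.Linear.
    weight = torch.stack([m.weight for m in modules])
    bias = torch.stack([m.bias for m in modules])
    # Run every module on every input in one batched op: (B, M, out).
    all_out = torch.einsum("bi,moi->bmo", x, weight) + bias
    # Keep only the requested modules per data point: (B, K, out).
    batch_idx = torch.arange(x.shape[0], device=x.device).unsqueeze(1)
    return all_out[batch_idx, ind]

out = selected_forward(decision_modules, x, ind)
print(out.shape)                         # torch.Size([32, 4, 768])
```

The trade-off is extra compute (10 matrix products per input instead of 4) in exchange for no Python loop and a small memory footprint. Gathering the weights per data point instead (`weight[ind]`) would materialize a `(32, 4, 768, 768)` tensor, roughly 300 MB in float32. For modules that are not plain `nn.Linear` layers, the model-ensembling utilities in `torch.func` (`stack_module_state` together with `vmap`) are another route worth checking.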

However, the following code throws an error:

```python
import torch
import torch.nn as nn

linears = nn.ModuleList([nn.Linear(768, 768) for i in range(10)])
ind=torch.randint(0,10,(4,4))
input=torch.rand(32,768)

out=linears[ind](input)
```

This raises the following error:

```
File ~\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\container.py:334, in ModuleList.__getitem__(self, idx)
    332     return self.__class__(list(self._modules.values())[idx])
    333 else:
--> 334     return self._modules[self._get_abs_string_index(idx)]

File ~\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\container.py:314, in ModuleList._get_abs_string_index(self, idx)
    312 def _get_abs_string_index(self, idx):
    313     """Get the absolute index for the list of modules."""
--> 314     idx = operator.index(idx)
    315     if not (-len(self) <= idx < len(self)):
    316         raise IndexError(f"index {idx} is out of range")

TypeError: only integer tensors of a single element can be converted to an index
```
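A short sketch of why the call fails, based on the frames in the traceback: `ModuleList.__getitem__` handles slices specially and otherwise passes the index through `operator.index`, which accepts a single-element integer tensor but not a multi-element one like `ind`. Indexing element by element works, but it produces a plain Python list and brings the loop back.

```python
import operator
import torch
import torch.nn as nn

linears = nn.ModuleList([nn.Linear(768, 768) for i in range(10)])
ind = torch.randint(0, 10, (4, 4))

# A single-element integer tensor converts cleanly to a Python int.
single = linears[ind[0, 0]]
print(type(single).__name__)             # Linear

# A multi-element tensor does not; this is the TypeError in the traceback.
try:
    operator.index(ind)
except TypeError as err:
    print(err)

# Element-by-element indexing yields a plain list of modules (one per entry).
selected = [linears[int(i)] for i in ind.flatten()]
print(len(selected))                     # 16
```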

Any help would be greatly appreciated.