Is there a way to use a list of indices to simultaneously access the modules of nn.ModuleList in Python?
I am working with a PyTorch ModuleList, as described below:
decision_modules = nn.ModuleList([nn.Linear(768, 768) for i in range(10)])
Our input data has the shape x=torch.rand(32,768). Here 32 is the batch size and 768 is the feature dimension.
Now, for each data point in a minibatch of 32 data points, we want to select 4 decision modules from the list decision_modules. The 4 decision modules are selected using an index matrix, as explained below.
I have an index matrix ind, created as torch.randint(0,10,(4,4)).
I want a solution that does not use loops, as loops slow down execution significantly.
But the following code throws an error.
import torch
import torch.nn as nn
linears = nn.ModuleList([nn.Linear(768, 768) for i in range(10)])
ind=torch.randint(0,10,(4,4))
input=torch.rand(32,768)
out=linears[ind](input)
The following error was observed:

File ~\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\container.py:334, in ModuleList.__getitem__(self, idx)
    332     return self.__class__(list(self._modules.values())[idx])
    333 else:
--> 334     return self._modules[self._get_abs_string_index(idx)]

File ~\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\container.py:314, in ModuleList._get_abs_string_index(self, idx)
    312 def _get_abs_string_index(self, idx):
    313     """Get the absolute index for the list of modules."""
--> 314     idx = operator.index(idx)
    315     if not (-len(self) <= idx < len(self)):
    316         raise IndexError(f"index {idx} is out of range")

TypeError: only integer tensors of a single element can be converted to an index
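From the traceback, it looks like ModuleList indexing passes anything that is not a slice through operator.index, which only accepts values convertible to a single Python integer. The small check below is a sketch of that reading, not a statement about every PyTorch version; the layer size of 8 is an arbitrary choice to keep it light.

```python
import operator

import torch
import torch.nn as nn

# Small layers (8 features) are an arbitrary choice for this check.
linears = nn.ModuleList([nn.Linear(8, 8) for _ in range(10)])

single = linears[3]    # an int index returns one nn.Linear
subset = linears[2:5]  # a slice returns a new nn.ModuleList holding 3 modules

# A one-element integer tensor can be converted to a Python int ...
one = operator.index(torch.tensor(3))

# ... but a multi-element tensor, such as the (4, 4) index matrix, cannot.
ind = torch.randint(0, 10, (4, 4))
try:
    operator.index(ind)
    failed = False
except TypeError:
    failed = True

print(type(single).__name__, len(subset), one, failed)
```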
Any help will be highly useful.
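One possible loop-free direction is sketched below; treat it as one option rather than the definitive answer. It assumes the index matrix has one row of 4 module indices per data point, i.e. shape (32, 4) rather than the (4, 4) used in the code above, and that every module is an nn.Linear with the same shape. The names W, b, all_out, and gather_idx are introduced only for this illustration. The idea is to stack the weights of all 10 modules, apply all of them in one batched einsum, and then gather the 4 selected outputs per data point.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_modules, feat, batch, k = 10, 768, 32, 4
linears = nn.ModuleList([nn.Linear(feat, feat) for _ in range(num_modules)])

# Assumption: one row of k module indices per data point, shape (32, 4).
ind = torch.randint(0, num_modules, (batch, k))
x = torch.rand(batch, feat)

# Stack the parameters of all modules into single tensors.
W = torch.stack([m.weight for m in linears])  # (10, 768, 768) as (module, out, in)
b = torch.stack([m.bias for m in linears])    # (10, 768)

# Apply every module to every data point in one batched operation.
all_out = torch.einsum("bi,moi->bmo", x, W) + b  # (32, 10, 768)

# Keep only the k selected modules for each data point.
gather_idx = ind.unsqueeze(-1).expand(-1, -1, feat)  # (32, 4, 768)
out = all_out.gather(1, gather_idx)                  # (32, 4, 768)

# Slow loop-based reference, used only to check the vectorized result.
ref = torch.stack([
    torch.stack([linears[int(j)](x[n]) for j in ind[n]])
    for n in range(batch)
])
print(out.shape, torch.allclose(out, ref, atol=1e-4))
```

This computes the outputs of all 10 modules and discards the 6 that were not selected for each data point, so it trades extra arithmetic for the absence of Python loops. Whether that is faster in practice depends on the hardware and on how many modules are selected relative to the total.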