I am currently trying to use tensor indexing to select NNs from a list. What I want to do is list_of_NNs[index_tensor](evaluate_on_this).
This does not work with lists, and I could not find a way around it besides using generator expressions, which is of course terribly slow. Is there a data structure or function that allows me to do this efficiently?
Specifically, I want to evaluate on batches, of course.
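To make the question concrete, here is a minimal sketch of the slow per-sample version described above. The layer sizes, the batch size, and the values in index_tensor are placeholders chosen only for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Placeholder networks (assumed sizes); any modules with equal output shapes would do.
list_of_NNs = [nn.Linear(10, 3) for _ in range(4)]
index_tensor = torch.tensor([2, 0, 3, 3, 1])  # one network index per sample
evaluate_on_this = torch.randn(5, 10)         # batch of 5 samples

# Per-sample evaluation: one Python-level forward call per batch element.
out = torch.stack([
    list_of_NNs[i](sample)
    for i, sample in zip(index_tensor.tolist(), evaluate_on_this)
])
```

This gives the intended result, but the number of forward calls grows with the batch size rather than with the number of networks.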
Edit: Bad formatting…
I’m unsure if I understand the use case correctly, but do you want to create multiple modules, store them in a list, and iterate over them afterwards?
Would this example work?
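import torch
from torch import nn

class MyModel(nn.Module):
    def __init__(self, modules):
        super().__init__()
        self.module_list = nn.ModuleList(modules)

    def forward(self, x):
        for module in self.module_list:
            x = module(x)
        return x

# example layers; sizes are assumed so that they match the (1, 10) input below
modules = [nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10)]
model = MyModel(modules)
x = torch.randn(1, 10)
out = model(x)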
If not, could you describe your use case a bit more, please?
Sorry, my initial post had some formatting issues.
My main goal is to define multiple modules, say NN_1,…,NN_m. I then want to store them in some structure (a list in my current code), say list_NN. Then I want to “tensor index” that list: given a tensor of indices index_tensor, I want list_NN[index_tensor], where each entry is the NN corresponding to that index.
I then want to evaluate this “tensor of NNs” on some batch of tensors.
I could not get your approach to work because ModuleList also seems to be using “basic” lists.
I also tried using vmap, but it raises an error as well, saying that only single-element integer tensors can be used for indexing.
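For reference, the indexing described above can be emulated on top of nn.ModuleList by grouping the batch by index, so that each network is called once on all samples assigned to it. This is only a sketch: the class name IndexedEnsemble is hypothetical, the layer sizes are placeholders, and it assumes every network returns outputs of the same shape and dtype.

```python
import torch
from torch import nn

class IndexedEnsemble(nn.Module):  # hypothetical helper, not a PyTorch class
    def __init__(self, modules):
        super().__init__()
        self.module_list = nn.ModuleList(modules)

    def forward(self, x, index_tensor):
        out = None
        # One forward call per distinct index instead of one per sample.
        for i in torch.unique(index_tensor).tolist():
            mask = index_tensor == i
            y = self.module_list[i](x[mask])
            if out is None:
                # Allocate the result once the per-sample output shape is known.
                out = y.new_empty((x.shape[0],) + tuple(y.shape[1:]))
            out[mask] = y
        return out

torch.manual_seed(0)
list_NN = [nn.Linear(10, 3) for _ in range(4)]  # placeholder networks
model = IndexedEnsemble(list_NN)
x = torch.randn(6, 10)
index_tensor = torch.tensor([0, 3, 1, 1, 0, 2])
out = model(x, index_tensor)  # row j is list_NN[index_tensor[j]](x[j])
```

The Python loop now runs at most m times regardless of batch size, which tends to help when the batch is much larger than the number of networks.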
Thanks for clarifying. The idea to use vmap sounds reasonable, and your example is also quite similar to this model ensembling example.
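A rough sketch of how the ensembling idea could be adapted to per-sample indices, assuming PyTorch 2.x (torch.func) and that all networks share exactly the same architecture. The network definition, sizes, and the helper name call_one are made up for illustration. The trick is to index the stacked parameters with index_tensor rather than indexing a list of modules.

```python
import copy
import torch
from torch import nn
from torch.func import functional_call, stack_module_state

torch.manual_seed(0)

# Placeholder networks; they must have identical structure to be stackable.
nets = [nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 3))
        for _ in range(4)]

# Every entry in params gets a leading dimension of size m = len(nets).
params, buffers = stack_module_state(nets)

# Structure-only copy of one network, used as the "stateless" template.
base = copy.deepcopy(nets[0]).to("meta")

def call_one(p, b, sample):
    # Runs the template with one set of parameters on one sample.
    return functional_call(base, (p, b), (sample,))

index_tensor = torch.tensor([2, 0, 3, 3, 1])
x = torch.randn(5, 10)

# Gather one parameter set per sample; this is ordinary tensor indexing.
sel_params = {k: v[index_tensor] for k, v in params.items()}
sel_buffers = {k: v[index_tensor] for k, v in buffers.items()}

# vmap over the per-sample parameters and the batch dimension of x together.
out = torch.vmap(call_one)(sel_params, sel_buffers, x)
```

Note that indexing the stacked parameters materializes one copy of the weights per sample, so this trades memory for speed and is probably only attractive for small networks.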