List of Lists with nn.Module()


I want to properly register a list of nn.Module()s that I structure as a list of lists, because they represent connections in a matrix. What is the best way to register a list of lists of modules in PyTorch? As far as I understand from the docs, nn.ModuleList() requires an iterable containing modules, so I'm not sure it is useful for my purposes.
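For context, here is a minimal sketch of why registration matters. It assumes nn.Linear as a stand-in for the real connection modules, and the class name PlainGrid is hypothetical. Modules stored in a plain Python list of lists are not registered as submodules, so their parameters do not show up in parameters() or state_dict() (and are not moved by .to() or seen by an optimizer):

```python
import torch.nn as nn


class PlainGrid(nn.Module):
    """Hypothetical example: stores modules in a plain list of lists."""

    def __init__(self, n=3):
        super().__init__()
        # Plain Python lists are not tracked by nn.Module, so these
        # Linear layers are NOT registered as submodules.
        self.grid = [[nn.Linear(4, 4) for _ in range(n)] for _ in range(n)]


model = PlainGrid()
print(len(list(model.parameters())))  # 0: nothing was registered
print(len(model.state_dict()))        # 0
```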

I was thinking of creating an nn.Module utility that holds each of the inner lists, and then creating an nn.ModuleList from a list of these modules.
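The wrapper approach described above could be sketched roughly like this. Row and Matrix are hypothetical names, and nn.Linear again stands in for whatever module sits at each matrix position. Each Row keeps its cells in an nn.ModuleList, and the outer nn.ModuleList holds the Row modules, so everything is registered:

```python
import torch.nn as nn


class Row(nn.Module):
    """Hypothetical wrapper holding one row of the connection matrix."""

    def __init__(self, n_cols, dim):
        super().__init__()
        # One module per column; ModuleList registers each of them.
        self.cells = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_cols))


class Matrix(nn.Module):
    """Hypothetical container: an nn.ModuleList of Row modules."""

    def __init__(self, n_rows, n_cols, dim):
        super().__init__()
        self.rows = nn.ModuleList(Row(n_cols, dim) for _ in range(n_rows))


model = Matrix(2, 3, 4)
print(len(list(model.parameters())))  # 12: 6 Linear layers x (weight + bias)
print(model.rows[1].cells[2])         # the module at position (1, 2)
```

This works, but the Row class only exists to hold a ModuleList, which is why nesting ModuleList directly is usually simpler.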

Thanks in advance.

ModuleList itself is a module, so you can just do something like:

nn.ModuleList(nn.ModuleList(... for _ in range(5)) for _ in range(5)) to create a 5x5 "grid" of modules. If you need a more complex structure (for example, rows of different lengths), you can nest nn.ModuleList([nn.ModuleList(...), ...]) in whatever shape you need.
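As a minimal sketch of the nested approach, assuming nn.Linear(8, 8) as a placeholder for the actual connection module (the Grid class and the ragged example are hypothetical), the grid can be built and indexed as grid[i][j], with every inner module registered:

```python
import torch.nn as nn


class Grid(nn.Module):
    """Hypothetical n x n grid of modules built from nested ModuleLists."""

    def __init__(self, n=5, dim=8):
        super().__init__()
        # Outer ModuleList of inner ModuleLists: every Linear is registered.
        self.grid = nn.ModuleList(
            nn.ModuleList(nn.Linear(dim, dim) for _ in range(n)) for _ in range(n)
        )


model = Grid()
print(len(list(model.parameters())))  # 50: 25 Linear layers x (weight + bias)
print(model.grid[2][3])               # Linear(in_features=8, out_features=8, bias=True)

# Rows do not need the same length; any nesting shape is allowed.
ragged = nn.ModuleList([
    nn.ModuleList([nn.Linear(8, 8)]),
    nn.ModuleList([nn.Linear(8, 8), nn.Linear(8, 8)]),
])
print(len(list(ragged.parameters())))  # 6: 3 Linear layers x (weight + bias)
```

The state_dict keys follow the nesting (for example grid.2.3.weight), so saving and loading also work as expected.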

That said, this makes me think there might be a better way to do what you want. Could you be a bit more specific about what you are trying to do?
