Way to create a Module that builds on top of other swappable Modules?

To keep code changes between experiments minimal, I have multiple Modules that share exactly the same forward signature. Let’s call them A, B, C, D, and E.

Although they all share the same forward signature, each one is constructed with different arguments, because each has its own options. Given that, how can I create another Module that uses any pair of these modules, say A and B, or A and C?
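To make the setup concrete, here is a minimal sketch of two such interchangeable modules. The names A and B come from the question, but their layers and constructor options (hidden, dropout, num_layers) are hypothetical stand-ins: both map a (batch, 16) tensor to a (batch, 16) tensor in forward, while their constructors take different arguments.

```python
import torch
from torch import nn


class A(nn.Module):
    # Hypothetical module: an MLP whose options are a hidden size and a dropout rate.
    def __init__(self, hidden: int = 32, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(16, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 16),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class B(nn.Module):
    # Hypothetical module: a stack of square linear layers; its only option is the depth.
    def __init__(self, num_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(num_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = torch.tanh(layer(x))
        return x


# Same forward contract, different constructor arguments.
x = torch.randn(4, 16)
print(A(hidden=64)(x).shape)     # torch.Size([4, 16])
print(B(num_layers=3)(x).shape)  # torch.Size([4, 16])
```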

Option 1

Build ABModule, ACModule, etc. This would lead to tons of boilerplate and other repeated code.

Option 2

class HigherLevelModule(torch.nn.Module):
    def __init__(self, Module1, Module2):
        super().__init__()  # required before assigning submodules
        self.module1 = Module1(...)
        self.module2 = Module2(...)

This is a bad option because Module1 and Module2 each need their own module-specific constructor arguments, which HigherLevelModule would then have to know about.

Option 3

class HigherLevelModule(torch.nn.Module):
    def __init__(self, module1, module2):
        super().__init__()  # required before assigning submodules
        self.module1 = module1
        self.module2 = module2

This seems okay, but I don’t see this pattern anywhere in PyTorch.

Will Option 3 break any of the underlying Module functionality? And in either case, is there a better approach?

Edit: Note that each module has its own parameters.

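Option 3 should work. Assigning an nn.Module instance to an attribute (after super().__init__() has run) registers it as a submodule, so its parameters still show up in parameters() and state_dict() and move along with .to(). Just be careful that module1 and module2 aren’t modified by anything outside, since the wrapper holds references to them rather than copies.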
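A minimal sketch of Option 3 filled in, assuming (hypothetically) that the wrapper simply chains module1 into module2 in forward, with plain nn.Linear layers standing in for A through E. Because nn.Module registers any module assigned as an attribute, the children’s parameters appear under named_parameters(); nn.Sequential likewise accepts already-constructed modules. The last line shows why outside modification matters: `first` and `model.module1` are the same object.

```python
import torch
from torch import nn


class HigherLevelModule(nn.Module):
    # Option 3: receive already-constructed modules and store them as attributes.
    def __init__(self, module1: nn.Module, module2: nn.Module):
        super().__init__()  # must run before any submodule is assigned
        self.module1 = module1  # nn.Module.__setattr__ registers it as a child
        self.module2 = module2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Hypothetical composition: feed module1's output into module2.
        return self.module2(self.module1(x))


# Stand-in sub-modules; any pair with compatible shapes would do.
first = nn.Linear(16, 8)
second = nn.Linear(8, 4)
model = HigherLevelModule(first, second)

# Both children's parameters are visible to optimizers, state_dict(), .to(), etc.
print([name for name, _ in model.named_parameters()])
# ['module1.weight', 'module1.bias', 'module2.weight', 'module2.bias']

# The wrapper holds references, not copies, so it shares parameters with `first`.
print(model.module1 is first)  # True
```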

I think if you wanted to make it nicer, instead of passing in an already-initialized module, you could pass in a function that takes no arguments and constructs the module, something like lambda: Module1(...). That’s just personal preference, though.
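A minimal sketch of the factory idea, under the same hypothetical setup (the wrapper chains module1 into module2, and nn.Linear layers stand in for A through E). The wrapper calls each factory once in __init__, so the constructed modules are private to it, which sidesteps the outside-modification concern while keeping module-specific arguments at the call site. functools.partial from the standard library is shown as an alternative to a lambda for binding constructor arguments.

```python
import functools

import torch
from torch import nn


class HigherLevelModule(nn.Module):
    # Factory variant: receive zero-argument callables that each build a module.
    def __init__(self, make_module1, make_module2):
        super().__init__()
        # Each call creates a fresh instance owned only by this wrapper,
        # so nothing outside holds a reference to these sub-modules.
        self.module1 = make_module1()
        self.module2 = make_module2()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Hypothetical composition: chain module1 into module2.
        return self.module2(self.module1(x))


# Module-specific arguments stay at the call site, bound by a lambda...
model_a = HigherLevelModule(lambda: nn.Linear(16, 8), lambda: nn.Linear(8, 4))

# ...or equivalently by functools.partial.
model_b = HigherLevelModule(
    functools.partial(nn.Linear, 16, 8, bias=False),
    functools.partial(nn.Linear, 8, 4),
)

# Two wrappers built from the same factory do not share parameters.
make_first = functools.partial(nn.Linear, 16, 8)
m1 = HigherLevelModule(make_first, lambda: nn.Linear(8, 4))
m2 = HigherLevelModule(make_first, lambda: nn.Linear(8, 4))
print(m1.module1 is m2.module1)  # False
```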