I have a parent nn.Module parent that stores multiple custom-defined child nn.Modules as attributes, and I want one of them to not appear in parent.children().
In other words, I want to save an nn.Module as an attribute of another nn.Module without registering its parameters with the outer nn.Module. How can I do this?
As the modules arise systematically (and the relevant attribute is expected to be of type Module elsewhere), it would be nice to not have to wrap the child Module.
I had hoped there was some way to go through __setattr__ with the child Module treated as a plain object (roughly as in the sketch below), or would that still make the parent Module register the child in children()?
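Something along these lines, where hidden_child is just a made-up attribute name for illustration (untested sketch):

import torch.nn as nn

class Parent(nn.Module):
    def __init__(self, child: nn.Module):
        super().__init__()
        # Bypass nn.Module.__setattr__, so the child becomes a plain
        # instance attribute and never ends up in self._modules.
        object.__setattr__(self, "hidden_child", child)

parent = Parent(nn.Linear(2, 2))
print(list(parent.children()))    # [] (the child is not registered)
print(list(parent.parameters()))  # [] (and neither are its parameters)
print(parent.hidden_child)        # but normal attribute access still works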
I don’t think there is a clean approach that still uses an nn.Module but prevents it from being registered inside the parent. You could check this code and try to see if changing the self._modules attribute might work, but note that this is an internal attribute and I don’t know what unwanted side effects you might trigger.
Maybe you could explain your use case a bit more so that we could think about another approach?
I have the same requirement. My use case is that I want to add two versions of the same submodule, one with full functionality (metrics, additional computations for validation) and one stripped down for training only. Both share the same parameters:
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self, m: nn.Module):
        super().__init__()
        self.m = m
        self._m_train = m.strip()  # Remove any functionality not required for training

    def forward(self, input):
        return self._m_train(input)

my_module = MyModule(m)  # m is the user-defined submodule with a custom .strip()
optimizer = torch.optim.SGD(my_module.parameters(), lr=0.1)
Fortunately, my_module.parameters() doesn’t seem to contain duplicate parameters. But it seems cleaner (at least to me) to not have both versions of the same m registered as children.
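That matches how parameters() behaves: it deduplicates shared tensors, e.g. (standalone toy example):

shared = nn.Linear(2, 2)

class Twice(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = shared  # the same module registered under two names
        self.b = shared

print(len(list(Twice().parameters())))  # 2 (weight and bias), not 4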
One quick and dirty hack I did was to wrap the submodule in a tuple or list. For example:
class MyModule(nn.Module):
    def __init__(self, m: nn.Module):
        super().__init__()
        self._m = [m]  # a plain list is not registered as a submodule
        self._m_train = m.strip()  # Remove any functionality not required for training

    def forward(self, input):
        return self._m_train(input)

    # Optional: expose the hidden module under its original name
    @property
    def m(self):
        return self._m[0]

my_module = MyModule(m)
optimizer = torch.optim.SGD(my_module.parameters(), lr=0.1)
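A quick sanity check of what ends up registered with this trick (self-contained stand-in for the pattern above; Holder is a made-up name):

class Holder(nn.Module):
    def __init__(self, m: nn.Module):
        super().__init__()
        self._m = [m]        # hidden from registration
        self._m_train = m    # registered as usual

holder = Holder(nn.Linear(2, 2))
print(list(holder.children()))         # only the registered copy shows up
print(len(list(holder.parameters())))  # 2: weight and bias, counted once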
Btw, this is also nice if you have a module that registers an altered version of itself as an attribute but want to avoid infinite recursion in module-recursing functions like .to(), for example.
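A hypothetical illustration of that last point, simplified to a plain self-reference (SelfRef is a made-up class, not from my actual code):

class SelfRef(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        # `self.alt = self` would register the module as its own child and
        # send .to(), .apply(), .children(), etc. into infinite recursion;
        # wrapping it in a list keeps the reference without registering it.
        self.alt = [self]

sr = SelfRef()
sr.to("cpu")                # terminates fine
print(list(sr.children()))  # only the Linear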