Saving nn.Module to Parent nn.Module without Registering Parameters

Hi,

I have a parent nn.Module, parent, that stores multiple custom-defined child nn.Modules as attributes, and I want one of them not to appear in parent.children().

In other words, I want to store an nn.Module as an attribute of another nn.Module without registering its parameters with the parent module. How can I do this?

Thanks!

If you are using a custom class (or just a Python object) for the children, I would guess that they won’t be registered in the parent nn.Module.
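For instance, something like this minimal sketch (the Holder wrapper is just an illustration, not an existing utility): since the wrapper is a plain Python object rather than an nn.Module, nn.Module.__setattr__ won’t pick it up, and neither the child nor its parameters show up in the parent:

import torch.nn as nn

class Holder:
  # Plain Python object, so nn.Module.__setattr__ ignores it
  def __init__(self, module):
    self.module = module

class Parent(nn.Module):
  def __init__(self, child: nn.Module):
    super().__init__()
    self.child = Holder(child)  # stored in __dict__, not in self._modules

p = Parent(nn.Linear(4, 4))
print(list(p.children()))    # []
print(list(p.parameters()))  # []
print(p.child.module)        # the Linear is still accessible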

Hi, thanks for the answer!

As the modules are created systematically (and the relevant attribute is expected to be of type Module elsewhere), it would be nice not to have to wrap the child Module.

I had hoped there was some way to go through __setattr__ with the child Module treated as a plain object, or would the parent Module still register the child in children()?

Thanks again!

I don’t think there is a clean approach to keep using an nn.Module but prevent it from being registered inside the parent. You could check this code and see whether modifying the self._modules attribute directly might work, but note that this is an internal attribute and I don’t know what unwanted side effects you might trigger.
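As a rough sketch of that idea (based on how nn.Module.__setattr__ decides what to register; this is not an officially supported pattern), you could also bypass the registration logic entirely with object.__setattr__:

import torch.nn as nn

class Parent(nn.Module):
  def __init__(self, child: nn.Module):
    super().__init__()
    # Bypasses nn.Module.__setattr__, so the child never lands in self._modules
    object.__setattr__(self, "child", child)

p = Parent(nn.Linear(4, 4))
print(list(p.children()))    # []
print(list(p.parameters()))  # []
print(p.child)               # still accessible as a normal attribute

Keep in mind that the hidden child will then also be skipped by .to(), .state_dict(), .train()/.eval(), and so on, which may or may not be what you want.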
Maybe you could explain your use case a bit more so that we could think about another approach?

I have the same requirement. My use case is that I want to add two versions of the same submodule, one with full functionality (metrics, additional computations for validation) and one stripped down for training only. Both share the same parameters:

import torch
import torch.nn as nn

class MyModule(nn.Module):
  def __init__(self, m: nn.Module):
    super().__init__()
    self.m = m
    self._m_train = m.strip()  # Remove any functionality not required for training

  def forward(self, input):
    return self._m_train(input)

my_module = MyModule(m)  # m is the full submodule
optimizer = torch.optim.SGD(my_module.parameters(), lr=0.1)

Fortunately, my_module.parameters() doesn’t seem to contain duplicate parameters. But it seems cleaner (at least to me) to not have both versions of the same m registered as children.
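For example, this quick check (just a toy Parent with one shared nn.Linear) shows that parameters() filters duplicates by identity:

import torch.nn as nn

class Parent(nn.Module):
  def __init__(self):
    super().__init__()
    shared = nn.Linear(4, 4)
    self.a = shared
    self.b = shared  # same module registered twice

p = Parent()
print(len(list(p.parameters())))  # 2 (weight + bias), not 4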

One quick and dirty hack I did was to wrap the submodule in a tuple or list. For example:

class MyModule(nn.Module):
  def __init__(self, m: nn.Module):
    super().__init__()
    self._m = [m]  # wrapping in a list hides m from the parent's registration
    self._m_train = m.strip()  # Remove any functionality not required for training

  def forward(self, input):
    return self._m_train(input)

  # Optional: still expose the full module as an attribute
  @property
  def m(self):
    return self._m[0]

my_module = MyModule(m)
optimizer = torch.optim.SGD(my_module.parameters(), lr=0.1)

Btw this is also nice if you have a module that registers an altered version of itself as an attribute but you want to avoid infinite recursion in module-recursing functions like .to(), for example.
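A contrived sketch of that situation (here the module simply keeps a reference to itself; an altered copy that still contains the original would behave the same way):

import torch
import torch.nn as nn

class SelfReferencing(nn.Module):
  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(4, 4)
    # Wrapping in a list keeps the reference out of self._modules,
    # so .to(), .apply(), .children(), etc. never recurse into it
    self._self_ref = [self]

m = SelfReferencing()
m.to(torch.float64)  # terminates; registering self directly would recurse forever in _apply()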