I need to compose a parent module with a child submodule. For OO compatibility, the parent object has to be able to access the child's attributes directly. Below is a reproducible example, with the output of each line in a comment.
I have two questions:

- Simply assigning the submodule as an attribute inside the parent's initializer also seems to work. By “works” I mean that both the parent module and the submodule get trained. So my first question is: in terms of training, is adding the submodule as a normal member equivalent to `self.add_module("child", SubModule())`? The thread “When to use add_module function?” does not explain this.
- I tried the code below, but none of the variants can access the child module's attributes from the parent object (the failure messages are in the comments). What is the right way to do this?
```python
import torch.nn as nn

class SubModule(nn.Module):
    def __init__(self):
        super(SubModule, self).__init__()
        self.h = 3

class Parent(nn.Module):
    def __init__(self):
        super(Parent, self).__init__()
        self.child = SubModule()
        self.m = 1

class ParentControl(nn.Module):
    def __init__(self):
        super(ParentControl, self).__init__()
        self.child = SubModule()
        self.m = 1

    def __getattr__(self, name):
        try:
            return self.__dict__[name]
        except KeyError:
            child = self.child
            return child.__getattr__(name)

class ParentControl2(nn.Module):
    def __init__(self):
        super(ParentControl2, self).__init__()
        self.child = SubModule()
        self.m = 1

    def __getattr__(self, name):
        try:
            return self.__getattr___[name]
        except AttributeError:
            child = self.child
            return child.__getattr__(name)

pp = Parent()
"m" in pp.__dict__.keys()      # True
"child" in pp.__dict__.keys()  # False
pp.child  # SubModule()
pp.m      # 1
pp.h      # AttributeError: 'Parent' object has no attribute 'h'

pp1 = ParentControl()
pp1.child  # RecursionError: maximum recursion depth exceeded while calling a Python object

pp2 = ParentControl2()
pp2.__getattr__("child")  # RecursionError: maximum recursion depth exceeded while calling a Python object
pp2.h                     # RecursionError: maximum recursion depth exceeded while calling a Python object
```
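For the first question, one way to check the equivalence empirically is to compare what each approach registers. My understanding is that `nn.Module.__setattr__` routes any `nn.Module` value into the same `_modules` registry that `add_module` writes to, so `parameters()` (and therefore the optimizer) should see the same tensors either way. The sketch below assumes this; the class names `ViaAttribute` and `ViaAddModule` are hypothetical, and it swaps in `nn.Linear` as the child because `SubModule` above has no parameters.

```python
import torch.nn as nn

class ViaAttribute(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain assignment: nn.Module.__setattr__ detects the Module value
        # and registers it as a child named "child"
        self.child = nn.Linear(4, 2)

class ViaAddModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Explicit registration under the same name
        self.add_module("child", nn.Linear(4, 2))

a, b = ViaAttribute(), ViaAddModule()
# Both should list the same registered child and parameter names
print([n for n, _ in a.named_children()])    # ['child']
print([n for n, _ in b.named_children()])    # ['child']
print([n for n, _ in a.named_parameters()])  # ['child.weight', 'child.bias']
print([n for n, _ in b.named_parameters()])  # ['child.weight', 'child.bias']
```

If that holds, `add_module` is mostly useful when the child's name is only known at runtime, for example when registering layers in a loop.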
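Regarding the second question, the output above shows that `child` is not in `pp.__dict__`, since submodules live in `_modules`. So inside `ParentControl.__getattr__`, the expression `self.child` fails normal lookup and re-enters `__getattr__`, which recurses. One pattern that may avoid this, offered as a sketch rather than a confirmed answer, is to fall back to `nn.Module`'s own `__getattr__` via `super()` and to fetch the child that same way. `DelegatingParent` is a hypothetical name.

```python
import torch.nn as nn

class SubModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.h = 3

class DelegatingParent(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = SubModule()  # stored in self._modules, not self.__dict__
        self.m = 1

    def __getattr__(self, name):
        # Only called when normal lookup fails (e.g. for "child" or "h")
        try:
            # nn.Module.__getattr__ searches _parameters, _buffers and _modules
            return super().__getattr__(name)
        except AttributeError:
            # Get the child through nn.Module's lookup instead of self.child,
            # so a missing child raises AttributeError instead of recursing
            child = super().__getattr__("child")
            # getattr() does a full lookup on the child, so plain attributes
            # kept in child.__dict__ (like h) are found as well
            return getattr(child, name)

p = DelegatingParent()
print(p.m)      # 1
print(p.h)      # 3
print(p.child)  # SubModule()
```

One caveat of this design: any name missing on the parent silently falls through to the child, so a typo on the parent surfaces as an `AttributeError` mentioning `SubModule` instead.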