I was trying to build a tree-like neural net with nested modules stored in ModuleList/ModuleDict objects. However, I ran into a maximum recursion error at the top node, the ‘root’ of the tree. To keep it simple, I created a minimal example that reproduces the error (PyTorch 1.2):
import torch.nn as nn

class TreeNode_Test(nn.Module):
    def __init__(self):
        super(TreeNode_Test, self).__init__()
        # The ModuleList registers this module as its own submodule.
        self.nodesInLevels = nn.ModuleList([self])

myModel = TreeNode_Test()
myModel  # evaluating this or myModel.nodesInLevels raises a maximum recursion error:
  File "C:\Users\mk23\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1042, in __repr__
    mod_str = repr(module)
  File "C:\Users\mk23\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1042, in __repr__
    mod_str = repr(module)
  File "C:\Users\mk23\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1036, in __repr__
    extra_repr = self.extra_repr()
RecursionError: maximum recursion depth exceeded
Your model references itself when you put self in your ModuleList. So the print function, which tries to print every module contained in your model, will recurse without end.
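To illustrate the point above, here is a minimal sketch in plain Python. ToyModule is a hypothetical stand-in, not PyTorch's actual implementation; it only assumes a repr that prints every child and never checks for cycles.

```python
class ToyModule:
    """Hypothetical stand-in for a module container; not PyTorch code."""

    def __init__(self, name):
        self.name = name
        self.children = []

    def __repr__(self):
        # Naive recursive repr: formats every child, with no cycle check.
        inner = ", ".join(repr(child) for child in self.children)
        return f"{self.name}({inner})"


# An ordinary tree (no cycles) prints fine.
leaf = ToyModule("leaf")
parent = ToyModule("parent")
parent.children.append(leaf)
print(repr(parent))

# A node that contains itself sends repr into unbounded recursion.
root = ToyModule("root")
root.children.append(root)

try:
    repr(root)
    hit_recursion_limit = False
except RecursionError:
    hit_recursion_limit = True

print(hit_recursion_limit)
```

The same thing happens with any traversal that walks children recursively without remembering which nodes it has already seen.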
I don’t think this behaviour is correct. If I change the parent class to object instead of nn.Module, or replace the nn.ModuleList with a plain Python list, then it works as expected, but then it won’t work with DataParallel: the model won’t be replicated properly across multiple GPUs, and I will end up with the dreaded tensors/parameters on different GPUs error…
It isn’t just print (which I could avoid by not calling it). Pretty much everything ends up in infinite recursion, e.g. module.apply(fn), which I can’t avoid using.
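One possible way around the apply() recursion, as a hedged sketch only: Module.modules() keeps track of the modules it has already yielded, so it visits each one once even when the model contains itself. apply_once below is a hypothetical helper, not a PyTorch API, and the nn.Linear layer is added purely for illustration.

```python
import torch
from torch import nn


class TreeNode_Test(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)  # hypothetical layer, for illustration
        self.nodesInLevels = nn.ModuleList([self])  # self-reference, as in the example


def apply_once(model, fn):
    """Hypothetical helper: call fn on each unique module exactly once.

    Relies on Module.modules() skipping modules it has already yielded.
    """
    for module in model.modules():
        fn(module)
    return model


visited = []
model = TreeNode_Test()
apply_once(model, lambda m: visited.append(type(m).__name__))
print(visited)
```

Note that this calls fn on a parent before its children, whereas apply() visits children first, so it is not a drop-in replacement if fn depends on that order. It also does nothing for print/repr, which will still recurse.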