I wrote a simple wrapper to add special methods to a given neural network. While the implementation below works well for general objects (e.g. strings, lists, etc.), I get a RecursionError when applying it to a torch.nn.Module. It seems that in the latter case the lookup of self.instance inside the __getattr__ method is unsuccessful, which hence falls back on __getattr__ again, leading to the infinite loop (I also tried self.__dict__['instance'] without luck).

What is going on here? What is the correct way of doing what I want?
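As far as I can tell, the reason self.instance never shows up in self.__dict__ is that nn.Module overrides __setattr__ and diverts Module-valued attributes into self._modules instead of the regular instance dict. A quick check that seems to confirm this (my own diagnosis, so take it with a grain of salt):

```python
import torch

holder = torch.nn.Module()
holder.sub = torch.nn.Linear(3, 3)

# nn.Module.__setattr__ stores Module-valued attributes in the
# _modules dict rather than in the regular instance __dict__:
print('sub' in holder.__dict__)   # False
print('sub' in holder._modules)   # True
```

My wrapper implementation: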
import torch


class MyWrapper(torch.nn.Module):
    def __init__(self, instance):
        super().__init__()
        self.instance = instance

    def __getattr__(self, name):
        print("trace", name)
        return getattr(self.instance, name)

# Working example
obj = "test string"
obj_wrapped = MyWrapper(obj)
print(obj_wrapped.split(" ")) # trace split\n ['test', 'string']
# Failing example
net = torch.nn.Linear(12, 12)
net.test_attribute = "hello world"
b = MyWrapper(net)
print(b.test_attribute) # RecursionError: maximum recursion depth exceeded
b.instance # RecursionError: maximum recursion depth exceeded
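Update: the following sketch seems to work for me, assuming the problem really is that nn.Module stores submodules in _modules rather than __dict__ ('instance' is just my attribute name, not anything from the torch API):

```python
import torch


class MyWrapper(torch.nn.Module):
    """Sketch: delegate unknown attributes to the wrapped object, whether
    or not that object is an nn.Module."""

    def __init__(self, instance):
        super().__init__()
        self.instance = instance  # an nn.Module lands in self._modules

    def __getattr__(self, name):
        try:
            # nn.Module.__getattr__ searches _parameters, _buffers and
            # _modules, which is where a wrapped nn.Module actually lives.
            return super().__getattr__(name)
        except AttributeError:
            if name == 'instance':
                raise  # guard against recursion if 'instance' was never set
            return getattr(self.instance, name)


net = torch.nn.Linear(12, 12)
net.test_attribute = "hello world"
b = MyWrapper(net)
print(b.test_attribute)                      # hello world
print(b.instance is net)                     # True
print(MyWrapper("test string").split(" "))   # ['test', 'string']
```

Note that __getattr__ is only invoked after normal attribute lookup fails, so for non-Module objects (which do land in __dict__) the delegation path is only taken for attributes of the wrapped object itself.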