When I tried `torch.compile()` on my project, it raised some errors. One of them comes from a hook in my model that checks the module name: `torch.compile()` changes the module name by adding the prefix `_orig_mod`.
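For context, the wrapper returned by `torch.compile()` keeps the original model as a submodule named `_orig_mod`, so a leaf named `encoder.fc` in the eager model appears as `_orig_mod.encoder.fc` in `named_modules()` of the compiled model. If the hook only needs a stable name, one option is to normalize it. This is a plain-Python sketch. `strip_compile_prefix` is a hypothetical helper, and it assumes `_orig_mod` is the only extra segment (other wrappers, such as `DistributedDataParallel`, add their own `module.` prefix, which this does not touch).

```python
# Hypothetical helper: not part of PyTorch.
ORIG_MOD = "_orig_mod"


def strip_compile_prefix(name: str) -> str:
    """Drop leading '_orig_mod' segments that torch.compile's wrapper adds
    to qualified module names, e.g. '_orig_mod.encoder.fc' -> 'encoder.fc'."""
    parts = name.split(".") if name else []
    # Loop in case the model was wrapped more than once.
    while parts and parts[0] == ORIG_MOD:
        parts.pop(0)
    return ".".join(parts)


for n in ["_orig_mod", "_orig_mod.encoder.fc", "encoder.fc"]:
    print(repr(n), "->", repr(strip_compile_prefix(n)))
```

The root wrapper entry `_orig_mod` maps to `""`, which matches the empty name `named_modules()` gives the root of an uncompiled model.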
The way I debug my model:
```python
...
model = torch.compile(model)
for p in model.modules():
    if len(list(p.children())) == 0:
        print(p.__name__)  # I store the module name in `module.__name__`
...
```
This is how I store the module name before registering the hook:
```python
...
for name, module in model.named_modules():
    if len(list(module.children())) == 0:
        module.__name__ = name
        module.register_forward_pre_hook(self.myhook())
...
```
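One way to keep the naming loop above unchanged is to compute names on the unwrapped model. This sketch assumes the compiled wrapper exposes the original model as `_orig_mod` (the `OptimizedModule` returned by `torch.compile` in PyTorch 2.x does, but it is a private attribute that may change). `name_leaf_modules` is a hypothetical helper; for an eager model it simply falls back to the model itself.

```python
import torch.nn as nn


def name_leaf_modules(model: nn.Module) -> None:
    # Unwrap a torch.compile()'d model if needed; `_orig_mod` is private,
    # so fall back to the model itself when the attribute is absent.
    base = getattr(model, "_orig_mod", model)
    for name, module in base.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            module.__name__ = name  # same attribute the original code uses


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
name_leaf_modules(model)
print([m.__name__ for m in model])  # ['0', '1']
```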
In this situation, do I need to change the way my hook works, or is there a better workaround to avoid this problem?
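If changing the hook is an option, one alternative is to stop reading a name attribute from the module altogether and capture the qualified name in a closure when the hook is registered. This is a sketch, not the original `self.myhook()`; `make_pre_hook` and `register_named_pre_hooks` are hypothetical names, and the hooks here just record which leaf ran. It registers on the eager model, so any later `torch.compile(model)` call would happen afterwards. Whether the hooks then fire as expected under compilation depends on the PyTorch version, in line with the update note.

```python
import torch
import torch.nn as nn


def make_pre_hook(name: str, calls: list):
    # The qualified name is captured at registration time, so the hook
    # never depends on module.__name__ or on any wrapper prefix.
    def hook(module: nn.Module, args):
        calls.append(name)
        return None  # returning None leaves the inputs unchanged
    return hook


def register_named_pre_hooks(model: nn.Module, calls: list):
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            handles.append(module.register_forward_pre_hook(make_pre_hook(name, calls)))
    return handles


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
calls = []
handles = register_named_pre_hooks(model, calls)  # register on the eager model
model(torch.randn(1, 4))
print(calls)  # ['0', '1', '2']
for h in handles:
    h.remove()  # detach the hooks when they are no longer needed
```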
UPDATE
Somehow I missed that modules and hooks are not fully supported yet: PyTorch 2.0 | PyTorch