Are module methods other than forward preserved? I expected scale to be there in the traced_mod.
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
x = torch.relu(x)
return x
def scale(self, x):
return x * 2
mod = MyModule()
traced_mod = torch.fx.symbolic_trace(mod)
x = torch.randn(1, 512)
orig_out = mod.scale(x)
traced_out = traced_mod.scale(x)
"""
Traceback (most recent call last):
File "fx_preserve_module_methods.py", line 25, in <module>
traced_out = traced_mod.scale(x)
File "~/venv_sys/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1177, in __getattr__
raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'MyModule' object has no attribute 'scale'
"""
PyTorch version:
python -c "import torch; print(torch.__version__)"
1.10.2+cu113