Hello everyone,
I would like to prune my model and then script it, but I get the following error: AttributeError: 'LnStructured' object has no attribute '__name__'
Can we prune a model and then script it?
It looks like LnStructured adds a forward pre-hook to the pruned module, and jit.script can't resolve that hook's name.
Full traceback:
```
Traceback (most recent call last):
  File "src/model_optimizer.py", line 226, in <module>
    script_module(pruned_model)
  File "src/model_optimizer.py", line 20, in script_module
    module = jit.script(module)
  File "/venv/lib/python3.8/site-packages/torch/jit/_script.py", line 1257, in script
    return torch.jit._recursive.create_script_module(
  File "/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 451, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 513, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/venv/lib/python3.8/site-packages/torch/jit/_script.py", line 587, in _construct
    init_fn(script_module)
  File "/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 491, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 513, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/venv/lib/python3.8/site-packages/torch/jit/_script.py", line 587, in _construct
    init_fn(script_module)
  File "/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 491, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 465, in create_script_module_impl
    hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)
  File "/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 758, in get_hook_stubs
    if pre_hook.__name__ in hook_map:
AttributeError: 'LnStructured' object has no attribute '__name__'
```
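To see the object the traceback is complaining about, here is a small sketch that inspects the hooks left on a single pruned layer. It reads `_forward_pre_hooks`, which is a private attribute of `nn.Module` and may change between PyTorch versions. As far as I can tell, the pruning method instance itself is registered as the pre-hook, and since it is an instance of a callable class rather than a plain function, it has no `__name__`, which is exactly the attribute `get_hook_stubs` tries to read.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(10, 10)
prune.ln_structured(layer, name="weight", amount=0.4, dim=1, n=float("-inf"))

# _forward_pre_hooks is a private dict mapping hook id -> hook callable.
hooks = list(layer._forward_pre_hooks.values())
for hook in hooks:
    # A plain Python function has __name__; an instance of a callable class does not.
    print(type(hook).__name__, hasattr(hook, "__name__"))
```

This should print `LnStructured False`.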
Here is the code to reproduce the error:
```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.jit as jit


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(10, 10),
            nn.ReLU(),
            nn.Linear(10, 10),
            nn.Sigmoid(),
        )

    def forward(self, t):
        return self.layers(t)


model = Model()
# Apply structured pruning (by -inf norm along dim 1) to every Linear layer.
for _, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        prune.ln_structured(module, name='weight', amount=0.4, dim=1, n=float('-inf'))

# Raises: AttributeError: 'LnStructured' object has no attribute '__name__'
jit.script(model)
```
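One possible workaround, assuming you only need the pruned weights in the scripted model and not the ability to keep updating the mask afterwards: call `prune.remove` on each pruned parameter before scripting. This makes the pruning permanent (the zeros are written into `weight`) and removes the forward pre-hook along with the `weight_orig` parameter and `weight_mask` buffer, so `jit.script` never sees the `LnStructured` object. This is a sketch and has not been checked against every PyTorch version.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.jit as jit


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(10, 10),
            nn.ReLU(),
            nn.Linear(10, 10),
            nn.Sigmoid(),
        )

    def forward(self, t):
        return self.layers(t)


model = Model()
for _, module in model.named_modules():
    if isinstance(module, nn.Linear):
        prune.ln_structured(module, name="weight", amount=0.4, dim=1, n=float("-inf"))
        # Bake the mask into `weight` and drop the forward pre-hook,
        # the weight_orig parameter and the weight_mask buffer.
        prune.remove(module, "weight")

scripted = jit.script(model)
x = torch.randn(2, 10)
print(torch.allclose(scripted(x), model(x)))
```

The trade-off is that the scripted module no longer carries the mask, so any further pruning has to be done on the eager model before scripting again.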
Thanks,
Thytu