Are jit.script & prune.ln_structured incompatible?

Hello everyone,

I would like to prune my model and then script it, but jit.script raises the following error: AttributeError: 'LnStructured' object has no attribute '__name__'

Can we prune a model and then script it?

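I see that prune.ln_structured registers an LnStructured instance as a forward pre-hook on each pruned module. jit.script then tries to read the hook's __name__ (see get_hook_stubs in the traceback), but the hook is a callable object rather than a function, so it has no such attribute.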
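
To illustrate what I think is going on, here is a minimal sketch in plain Python (no torch needed). The names function_hook and CallableHook are made up for the example; CallableHook only stands in for LnStructured, on the assumption that the pruning hook is an object with a __call__ method.

```python
def function_hook(module, inputs):
    """A hook written as a plain function: it carries a __name__."""
    return None


class CallableHook:
    """Hypothetical stand-in for a hook that is a callable object."""

    def __call__(self, module, inputs):
        return None


# Functions expose __name__, instances of a class do not.
function_has_name = hasattr(function_hook, "__name__")
instance_has_name = hasattr(CallableHook(), "__name__")

print(function_has_name, instance_has_name)
```

If that is right, any lookup of the form hook.__name__ would fail for an object hook in the same way as in my traceback.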

Full traceback:

Traceback (most recent call last):
  File "src/model_optimizer.py", line 226, in <module>
    script_module(pruned_model)
  File "src/model_optimizer.py", line 20, in script_module
    module = jit.script(module)
  File "/venv/lib/python3.8/site-packages/torch/jit/_script.py", line 1257, in script
    return torch.jit._recursive.create_script_module(
  File "/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 451, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 513, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/venv/lib/python3.8/site-packages/torch/jit/_script.py", line 587, in _construct
    init_fn(script_module)
  File "/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 491, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 513, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/venv/lib/python3.8/site-packages/torch/jit/_script.py", line 587, in _construct
    init_fn(script_module)
  File "/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 491, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 465, in create_script_module_impl
    hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)
  File "/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 758, in get_hook_stubs
    if pre_hook.__name__ in hook_map:
AttributeError: 'LnStructured' object has no attribute '__name__'

Here is the code to reproduce the error:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.jit as jit

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(10, 10),
            nn.ReLU(),

            nn.Linear(10, 10),
            nn.Sigmoid(),
        )

    def forward(self, t):
        return self.layers(t)

model = Model()

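# Structured pruning: mask 40% of the channels along dim=1 of each Linear
# weight, ranked by their -inf norm.
for _, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        prune.ln_structured(module, name='weight', amount=0.4, dim=1, n=float('-inf'))

# Fails here: AttributeError: 'LnStructured' object has no attribute '__name__'
jit.script(model)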
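
One workaround that might apply, assuming the pruning is meant to be final: prune.remove makes the pruning permanent by folding the mask into the weight and removing the forward pre-hook, so there is no LnStructured hook left when jit.script runs. This is only a sketch, and the nn.Sequential below is a shortened stand-in for my Model class.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Shortened stand-in for the Model class from the reproduction code.
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.Sigmoid(),
)

for module in model.modules():
    if isinstance(module, nn.Linear):
        prune.ln_structured(module, name="weight", amount=0.4, dim=1, n=float("-inf"))
        # Bake the mask into `weight`; this drops weight_orig, weight_mask
        # and the forward pre-hook that jit.script trips over.
        prune.remove(module, "weight")

# With the hooks gone, scripting should no longer hit the __name__ lookup.
scripted = torch.jit.script(model)

x = torch.randn(2, 10)
print(torch.allclose(scripted(x), model(x)))
```

The cost is that the mask is gone afterwards: the zeros stay in the weight, but further training would be free to move them away from zero, so this only fits if scripting is the last step.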

Thanks,
Thytu