Dynamically skip the front part of a TorchScript model

Hello

I’m using PyTorch 1.5.0.
I’d like to skip the first i layers of a TorchScript model.
The code below is my attempt. The method skip takes an input tensor x and a skip-point indicator i.

import torch
import torch.nn as nn

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(20, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 1)
        self.k = nn.ModuleList()
        self.k.append(self.fc1)
        self.k.append(self.fc2)
        self.k.append(self.fc3)

    @torch.jit.export
    def skip(self, x, i):
        for l in range(int(i), 3):
            x = self.k[l].forward(x)
        return x

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

model = Model().cuda()
x = torch.randn(20).cuda()
x2 = torch.randn(100).cuda() # Assume that x2 is the result of self.fc1 which is computed in advance.
idx = torch.tensor(1).cuda()
print(model.skip(x2,idx)) # This works

traced = torch.jit.trace(model, x) # Script
print(traced.skip(x2,idx)) # This does not work; an error occurs.

The error message is shown below.

Traceback (most recent call last):
  File "C:/Users/user/PycharmProjects/coin-modelgen/simple.py", line 33, in <module>
    traced = torch.jit.trace(model, x) # Script
  File "C:\Users\user\Anaconda3\envs\pytorch\lib\site-packages\torch\jit\__init__.py", line 875, in trace
    check_tolerance, _force_outplace, _module_class)
  File "C:\Users\user\Anaconda3\envs\pytorch\lib\site-packages\torch\jit\__init__.py", line 1021, in trace_module
    module = make_module(mod, _module_class, _compilation_unit)
  File "C:\Users\user\Anaconda3\envs\pytorch\lib\site-packages\torch\jit\__init__.py", line 716, in make_module
    return torch.jit._recursive.create_script_module(mod, make_stubs_from_exported_methods, share_types=False)
  File "C:\Users\user\Anaconda3\envs\pytorch\lib\site-packages\torch\jit\_recursive.py", line 305, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "C:\Users\user\Anaconda3\envs\pytorch\lib\site-packages\torch\jit\_recursive.py", line 361, in create_script_module_impl
    create_methods_from_stubs(concrete_type, stubs)
  File "C:\Users\user\Anaconda3\envs\pytorch\lib\site-packages\torch\jit\_recursive.py", line 279, in create_methods_from_stubs
    concrete_type._create_methods(defs, rcbs, defaults)
RuntimeError: 
Expected integer literal for index:
  File "C:/Users/user/PycharmProjects/coin-modelgen/simple.py", line 18
        def skip(self, x, i):
            for l in range(int(i), 3):
                x = self.k[l].forward(x)
                    ~~~~~~~~ <--- HERE
            return x

Could someone share any ideas (or an alternative approach) to solve this problem?

Moreover, I would like to know whether such dynamic skipping is also possible in a C++ implementation using libtorch.

According to torch.jit.script cannot index a ModuleList with an int returned from range() · Issue #47496 · pytorch/pytorch · GitHub, I found that i must be selected at compile time, not at runtime. So, does this mean there is no way to dynamically skip the front part of a TorchScript model?

Or is the naive code below the only way?

import torch
import torch.nn as nn

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(20, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 1)
        self.k = nn.ModuleList()
        self.k.append(self.fc1)
        self.k.append(self.fc2)
        self.k.append(self.fc3)
        # self.k = nn.Sequential(self.fc1, self.fc2, self.fc3)

    def from1(self, x):
        for index, v in enumerate(self.k[1:]):
            x = v.forward(x)
        return x

    def from2(self, x):
        for index, v in enumerate(self.k[2:]):
            x = v.forward(x)
        return x

    @torch.jit.export
    def skip(self, x, i: int):
        if i == 1:
            x = self.from1(x)
        elif i == 2:
            x = self.from2(x)
        return x

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

model = Model().cuda()
m = torch.jit.script(Model()).cuda()

x = torch.randn(20).cuda()
x2 = torch.randn(100).cuda() # Assume that x2 is the result of self.fc1, computed in advance.
idx = 1
print(model.skip(x2,idx)) # This works
print(m.skip(x2,idx)) # This works
tensor([-0.3310], device='cuda:0', grad_fn=<AddBackward0>)
tensor([-0.2432], device='cuda:0', grad_fn=<AddBackward0>)

If this is the only way, why are the two results different?

Thanks a lot.

One thing you can do is iterate over the entire ModuleList with enumerate and only call each module whose index is at or past the skip point i. Performance will be worse than if you skipped those first layers completely.
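
A minimal sketch of that approach, assuming the same three-layer toy model as above (the i: int annotation tells the compiler that i is an integer, not a tensor):

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.k = nn.ModuleList([
            nn.Linear(20, 100),
            nn.Linear(100, 100),
            nn.Linear(100, 1),
        ])

    @torch.jit.export
    def skip(self, x, i: int):
        # Walk over every layer; apply only the ones at or past index i.
        # The loop over the ModuleList is unrolled at compile time, so the
        # only runtime decision is the integer comparison against i.
        for index, layer in enumerate(self.k):
            if index >= i:
                x = layer.forward(x)
        return x

    def forward(self, x):
        for layer in self.k:
            x = layer.forward(x)
        return x

m = torch.jit.script(Model())
x2 = torch.randn(100)  # pretend this is the output of the first layer
print(m.skip(x2, 1))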

You can dynamically skip the front part, but only by annotating the LHS of your indexing statement with a module interface type. See the issue you linked for more details. I am currently working on extending this to ModuleList; it currently only works for ModuleDict.
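
For reference, here is a rough sketch of that interface pattern using an nn.ModuleDict. The names LinearInterface and skip are hypothetical, it follows the example in the linked issue, and it may require a newer PyTorch release than 1.5.0:

import torch
import torch.nn as nn

# The interface declares the signature shared by every entry in the dict,
# so the compiler can type-check an index that is only known at runtime.
@torch.jit.interface
class LinearInterface(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        pass

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layers = nn.ModuleDict({
            'fc1': nn.Linear(20, 100),
            'fc2': nn.Linear(100, 100),
            'fc3': nn.Linear(100, 1),
        })

    @torch.jit.export
    def skip(self, x, i: int):
        keys = ['fc1', 'fc2', 'fc3']
        for idx in range(i, len(keys)):
            # Annotating the left-hand side with the interface type is what
            # allows indexing the ModuleDict with a runtime string.
            layer: LinearInterface = self.layers[keys[idx]]
            x = layer.forward(x)
        return x

    def forward(self, x):
        return self.skip(x, 0)

m = torch.jit.script(Model())
x2 = torch.randn(100)  # pretend this is the output of fc1
print(m.skip(x2, 1))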

I think it’s because from1 and from2 perform different computations?