Scripting Incompatible With Repeated Branches

I’m trying to export a PyTorch model to TorchScript via scripting and I am stuck. I’ve created a toy class to showcase the issue:

import torch
from torch import nn


class SadModule(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer. Can optionally
    use a skip connection. The usage of the skip connection or not is an
    architectural choice.

    """
    def __init__(self, use_skip: bool):
        nn.Module.__init__(self)
        self.use_skip = use_skip
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        if self.use_skip:
            x_input = x
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x

It basically consists of only a linear layer and an optional skip connection. If I try to script the model using

mod1 = SadModule(False)
scripted_mod1 = torch.jit.script(mod1)

I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-10-a7ebc7af32c7> in <module>
----> 1 scripted_mod1 = torch.jit.script(mod)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb)
    895
    896     if isinstance(obj, torch.nn.Module):
--> 897         return torch.jit._recursive.create_script_module(
    898             obj, torch.jit._recursive.infer_methods_to_compile
    899         )

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
    350     check_module_initialized(nn_module)
    351     concrete_type = get_module_concrete_type(nn_module, share_types)
--> 352     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    353
    354 def create_script_module_impl(nn_module, concrete_type, stubs_fn):

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    408     # Compile methods if necessary
    409     if concrete_type not in concrete_type_store.methods_compiled:
--> 410         create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    411         torch._C._run_emit_module_hook(cpp_module)
    412         concrete_type_store.methods_compiled.add(concrete_type)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    302     property_rcbs = [p.resolution_callback for p in property_stubs]
    303
--> 304     concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
    305
    306

RuntimeError:

x_input is not defined in the false branch:
  File "<ipython-input-7-d08ed7ff42ec>", line 12
    def forward(self, x):
        if self.use_skip:
        ~~~~~~~~~~~~~~~~~
            x_input = x
            ~~~~~~~~~~~ <--- HERE
        x = self.layer(x)
        if self.use_skip:
and was used here:
  File "<ipython-input-7-d08ed7ff42ec>", line 16
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
                    ~~~~~~~ <--- HERE
        return x

So TorchScript doesn’t recognise that, for mod1, the True branch of either if statement will never be taken. Moreover, if we create an instance that actually uses the skip connection,

mod2 = SadModule(True)
scripted_mod2 = torch.jit.script(mod2)

we get the same error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-21-b5ca61d8aa73> in <module>
----> 1 scripted_mod2 = torch.jit.script(mod2)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb)
    895
    896     if isinstance(obj, torch.nn.Module):
--> 897         return torch.jit._recursive.create_script_module(
    898             obj, torch.jit._recursive.infer_methods_to_compile
    899         )

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
    350     check_module_initialized(nn_module)
    351     concrete_type = get_module_concrete_type(nn_module, share_types)
--> 352     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    353
    354 def create_script_module_impl(nn_module, concrete_type, stubs_fn):

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    408     # Compile methods if necessary
    409     if concrete_type not in concrete_type_store.methods_compiled:
--> 410         create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    411         torch._C._run_emit_module_hook(cpp_module)
    412         concrete_type_store.methods_compiled.add(concrete_type)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    302     property_rcbs = [p.resolution_callback for p in property_stubs]
    303
--> 304     concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
    305
    306

RuntimeError:

x_input is not defined in the false branch:
  File "<ipython-input-18-ac8b9713c789>", line 17
    def forward(self, x):
        if self.use_skip:
        ~~~~~~~~~~~~~~~~~
            x_input = x
            ~~~~~~~~~~~ <--- HERE
        x = self.layer(x)
        if self.use_skip:
and was used here:
  File "<ipython-input-18-ac8b9713c789>", line 21
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
                    ~~~~~~~ <--- HERE
        return x

So in this case TorchScript doesn’t understand that both if conditions will always be true and that x_input is in fact always defined.
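The underlying rule is that TorchScript requires a variable used after an if statement to be defined in both of its branches, and a plain bool attribute such as `self.use_skip` is treated as a runtime value rather than a compile-time constant, so the compiler cannot rule out the path on which `x_input` is never assigned. A minimal sketch of one possible workaround (not the fix suggested in the GitHub issue mentioned below) is to bind `x_input` unconditionally, so it is defined on every path; the class name `SkipHoisted` is hypothetical:

```python
import torch
from torch import nn


class SkipHoisted(nn.Module):
    """Hypothetical variant of SadModule that avoids the conditional binding."""

    def __init__(self, use_skip: bool):
        super().__init__()
        self.use_skip = use_skip
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        # Always bind x_input, so it is defined whatever the value of use_skip.
        x_input = x
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x


scripted_no_skip = torch.jit.script(SkipHoisted(False))
scripted_skip = torch.jit.script(SkipHoisted(True))
```

The extra assignment only creates another reference to the same tensor, so it should not add any computation.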

To avoid the issue, I could split the class into two separate classes, as in:

class SadModuleNoSkip(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer, without a
    skip connection.

    """
    def __init__(self):
        nn.Module.__init__(self)
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        x = self.layer(x)
        return x

class SadModuleSkip(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer, with a skip
    connection.

    """
    def __init__(self):
        nn.Module.__init__(self)
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        x_input = x
        x = self.layer(x)
        x = x + x_input
        return x

However, I am working on a huge code base and would have to repeat this for many classes, which is time-consuming and could introduce bugs. Moreover, the modules I’m working on are often large convolutional nets in which the ifs only control the presence of an additional batch normalization layer. It seems undesirable to me to have two classes that are identical in 99% of their blocks, save for a single batch norm layer.
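For the case where a flag only toggles an extra layer, such as the batch normalization described above, a common pattern that removes the branch from forward entirely is to store either the real layer or an `nn.Identity()` when the module is built. This is a minimal sketch assuming a 2D convolutional block; `ConvBlock`, `use_bn`, and the channel arguments are hypothetical names:

```python
import torch
from torch import nn


class ConvBlock(nn.Module):
    """Hypothetical conv block whose batch norm is an architectural choice."""

    def __init__(self, in_channels: int, out_channels: int, use_bn: bool):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # The choice is made here in plain Python; TorchScript only sees
        # whichever submodule ends up stored in self.norm.
        self.norm = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()

    def forward(self, x):
        # No branch on the flag, so there is nothing for the compiler to reject.
        return torch.relu(self.norm(self.conv(x)))


scripted_plain = torch.jit.script(ConvBlock(3, 4, use_bn=False))
scripted_bn = torch.jit.script(ConvBlock(3, 4, use_bn=True))
```

Because the layer is selected at construction time, a single class covers both variants and the rest of the block is written only once.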

Is there a way in which I can help TorchScript with its handling of branches?

I opened an issue on GitHub, and the maintainers were able to help me solve the problem.

Marking constant attributes with Final is the way to go, but it relies on a feature that is not yet available in the JIT compiler of the main releases (as of May 7, 2021).
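A minimal sketch of that fix applied to the toy class, assuming a PyTorch release whose compiler evaluates if conditions on constant attributes while compiling and skips the branch that cannot run. The annotation used here is `torch.jit.Final`; the rest of the class is unchanged from the question:

```python
import torch
from torch import nn


class SadModule(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer. Can optionally
    use a skip connection, chosen when the module is constructed.
    """

    # Final makes use_skip a constant of each scripted instance, so
    # `if self.use_skip:` can be resolved at compile time and the
    # unreachable branch is never compiled.
    use_skip: torch.jit.Final[bool]

    def __init__(self, use_skip: bool):
        super().__init__()
        self.use_skip = use_skip
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        if self.use_skip:
            x_input = x
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x


scripted_mod1 = torch.jit.script(SadModule(False))
scripted_mod2 = torch.jit.script(SadModule(True))
```

Since constants are part of a scripted module's type, the two instances are compiled separately, each with only the branches that its value of use_skip can reach.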