Recursively transforming PyTorch code to JIT script?

Many of the PyTorch tutorials seem to show a combination of tracing (torch.jit.trace) and scripting (torch.jit.script).

However, if all the code in a given nn.Module (including any submodules it contains) is compatible with TorchScript, is a single line enough to convert it, i.e. torch.jit.script(my_net_instance)?

See this example:

import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.flatten().sum() > 0:
            return x + 1
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h: int):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h
    
class MyOuterModule(torch.nn.Module):
    def __init__(self, my_cell):
        super(MyOuterModule, self).__init__()
        self.my_cell = my_cell

    def forward(self, x):
        # call the submodule directly rather than via .forward()
        new_h, _ = self.my_cell(x, 1)
        return new_h / new_h.sum()

# normal invocation
x = torch.rand(3, 4)
net = MyOuterModule(MyCell(MyDecisionGate()))
original_output = net(x)
print(original_output)

# use the script compiler to recursively convert
# the module and its submodules to TorchScript
scripted_net = torch.jit.script(net)
torchscript_output = scripted_net(x)
print(torchscript_output)

# check if they are close enough!
assert torch.allclose(original_output, torchscript_output), \
    "torchscript output does not match original"

Is this the correct way to think about things?

I am trying to better understand how TorchScript generation works.

Yes. As long as no special conditions apply and nothing has to be traced, you don't need to script the submodules separately; you should be able to script the whole model directly via torch.jit.script.
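
To illustrate both points, here is a minimal sketch. It reuses MyDecisionGate from the question; the Wrapper class and the variable names are hypothetical and only there for the demonstration. It checks that scripting a parent module also compiles its child, and shows one case where tracing would not be a drop-in replacement: the data-dependent if, which tracing freezes to whichever branch the example input took.

```python
import warnings
import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        # Data-dependent branch: scripting keeps it, tracing does not.
        if x.flatten().sum() > 0:
            return x + 1
        else:
            return -x

class Wrapper(torch.nn.Module):
    # Hypothetical parent module, used only to show recursive scripting.
    def __init__(self, dg):
        super().__init__()
        self.dg = dg

    def forward(self, x):
        return self.dg(x)

pos = torch.ones(2)
neg = -torch.ones(2)

# Scripting the parent compiles the child module as well.
scripted = torch.jit.script(Wrapper(MyDecisionGate()))
child_is_scripted = isinstance(scripted.dg, torch.jit.ScriptModule)
scripted_neg = scripted(neg)  # takes the else branch, returns -x

# Tracing records only the branch taken for the example input (x + 1 here).
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hide the TracerWarning about the branch
    traced = torch.jit.trace(MyDecisionGate(), pos)
traced_neg = traced(neg)  # still computes x + 1, ignoring the sign of the input

print(child_is_scripted)
print(scripted_neg)
print(traced_neg)
```

So for a model like the one in the question, where every submodule is scriptable, a single torch.jit.script call on the top-level module should be sufficient. Tracing tends to come into play only for parts that the script compiler cannot handle.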

Great, thanks for confirming!