# Recursively transforming PyTorch code to JIT script?

Many of the PyTorch tutorials show a combination of tracing (`torch.jit.trace`) and scripting (`torch.jit.script`).
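For context on why the two are mixed: tracing runs the module once on an example input and records only the tensor operations that actually executed, whereas scripting compiles the Python source itself, control flow included. The minimal sketch below illustrates this. The `Gate` module and the inputs are illustrative only, modeled on `MyDecisionGate` in the example further down. It shows how a trace bakes in whichever branch the example input happened to take.

```python
import torch

class Gate(torch.nn.Module):
    # Data-dependent branch: which path runs depends on the input values.
    def forward(self, x):
        if x.sum() > 0:
            return x + 1
        else:
            return -x

gate = Gate()
pos = torch.ones(2)
neg = -torch.ones(2)

# Tracing executes the module once on `pos` and records only the ops that ran
# (the `x + 1` branch). PyTorch emits a TracerWarning about the tensor-to-bool
# conversion in the `if`.
traced = torch.jit.trace(gate, pos)

# Scripting compiles the source, so both branches are preserved.
scripted = torch.jit.script(gate)

print(gate(neg))      # tensor([1., 1.])
print(traced(neg))    # tensor([0., 0.])  (wrong: the `x + 1` branch was baked in)
print(scripted(neg))  # tensor([1., 1.])
```

This is why tutorials reach for scripting whenever control flow depends on the data, and use tracing for code that TorchScript cannot compile but whose execution path does not vary.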

However, if all the code in a given `nn.Module` (including any modules it contains) is compatible with TorchScript, do you only need a single line to convert it, i.e. `torch.jit.script(my_net_instance)`?

See this example:

```python
import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.flatten().sum() > 0:
            return x + 1
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h: int):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

class MyOuterModule(torch.nn.Module):
    def __init__(self, my_cell):
        super(MyOuterModule, self).__init__()
        self.my_cell = my_cell

    def forward(self, x):
        # call the submodule directly rather than via .forward()
        new_h, _ = self.my_cell(x, 1)
        return new_h / new_h.sum()

# normal (eager) invocation
x = torch.rand(3, 4)
net = MyOuterModule(MyCell(MyDecisionGate()))
original_output = net(x)
print(original_output)

# use the script compiler to recursively turn the whole
# module hierarchy into TorchScript
scripted_net = torch.jit.script(net)
torchscript_output = scripted_net(x)
print(torchscript_output)

# check that the outputs are close enough
assert torch.allclose(original_output, torchscript_output), \
    "torchscript output does not match original"
```

Is this the correct way to think about things?

I'm trying to better understand how TorchScript generation works.

Yes, you don't need to script the submodules separately. Assuming no special conditions apply and no tracing is required, you should be able to script the whole model directly via `torch.jit.script`.
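A minimal sketch of what this recursive compilation looks like. The `Inner`/`Outer` modules and the `describe` method are hypothetical. `torch.jit.script` is called once on the top-level module, and the submodules that `forward` uses come back compiled as `ScriptModule`s. The answer doesn't name the "special conditions" it has in mind. One such condition, offered here as our own example, is that methods not reachable from `forward` are only compiled if they are marked with `@torch.jit.export`.

```python
import torch

class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()  # plain nn.Module, never scripted by hand

    def forward(self, x):
        return self.inner(x).sum(dim=1)

    # Hypothetical helper that forward() never calls; without @torch.jit.export
    # the compiler would skip it and it would not exist on the scripted module.
    @torch.jit.export
    def describe(self, x):
        return self.inner(x).mean()

model = Outer()
scripted = torch.jit.script(model)  # single call at the top level

# The submodules were compiled along with the parent.
print(isinstance(scripted.inner, torch.jit.ScriptModule))         # True
print(isinstance(scripted.inner.linear, torch.jit.ScriptModule))  # True
print(scripted.inner.code)  # TorchScript source generated for Inner.forward

x = torch.rand(3, 4)
print(torch.allclose(model(x), scripted(x)))  # True
print(scripted.describe(x))                   # exported method is callable
```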


Great, thanks for confirming!