Many of the PyTorch tutorials seem to show a combination of tracing (`torch.jit.trace`) and scripting (`torch.jit.script`).
However, it seems that if all the code in a given `nn.Module` (even one that contains other modules) is compatible with TorchScript, then you only need a single line to convert it, i.e. `torch.jit.script(my_net_instance)`. Is that right?
See this example:
import torch
class MyDecisionGate(torch.nn.Module):
    """Gate whose output depends on the sign of the input's total sum.

    The branch is data-dependent control flow, which is why this module is
    compiled with ``torch.jit.script`` in this example.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + 1`` if the elements of ``x`` sum to > 0, else ``-x``."""
        # ``sum()`` already reduces over every element, so no flatten is needed.
        if x.sum() > 0:
            return x + 1
        return -x
class MyCell(torch.nn.Module):
    """Cell that applies a 4->4 linear layer, a gate module, and ``tanh``.

    Args:
        dg: Module applied to the output of the linear layer.
    """

    def __init__(self, dg):
        super().__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor, h: int):
        """Compute ``tanh(dg(linear(x)) + h)``.

        Args:
            x: Input tensor whose last dimension is 4 (required by the
                ``Linear(4, 4)`` layer).
            h: Integer offset added before the ``tanh``.

        Returns:
            A 2-tuple holding the same result tensor twice.
        """
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h
class MyOuterModule(torch.nn.Module):
    """Wrapper that runs a cell with ``h = 1`` and normalizes its output.

    Args:
        my_cell: Module called as ``my_cell(x, 1)``; must return a 2-tuple
            whose first element is a tensor.
    """

    def __init__(self, my_cell):
        super().__init__()
        self.my_cell = my_cell

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the cell's first output divided by the sum of its elements."""
        # Call the module itself rather than ``.forward`` so that any hooks
        # registered on the cell are run.
        new_h, _ = self.my_cell(x, 1)
        return new_h / new_h.sum()
# Eager-mode reference run on a random 3x4 input.
x = torch.rand(3, 4)
net = MyOuterModule(MyCell(MyDecisionGate()))
original_output = net(x)
print(original_output)

# Compile the whole module tree (the outer module and, recursively, its
# submodules) to TorchScript with a single call.
scripted_net = torch.jit.script(net)
torchscript_output = scripted_net(x)
print(torchscript_output)

# The compiled module should reproduce the eager result.
assert torch.allclose(original_output, torchscript_output), (
    "torchscript output does not match original"
)
Is this the correct way to think about things?
I am trying to better understand how TorchScript generation works.