class LinearInLinear(nn.Module):
def __init__(self):
super(LinearInLinear, self).__init__()
self.l = nn.Linear(3, 5)
self.l1 = nn.Linear(5, 5)
def forward(self, x):
return self.l1(self.l(x + x))
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.l1 = LinearInLinear()
self.l = LinearInLinear()
def forward(self, input):
x1 = self.l1(input)
x2 = self.l(input)
return x1 + x2 + x2
if __name__ == "__main__":
dummy_input = (torch.rand(5, 3),)
# with torch.onnx.set_training(model, False):
trace = torch.jit.trace(model, dummy_input)
# _optimize_trace(trace, torch._C._onnx.OperatorExportTypes.ONNX)
trace.save("b.pt")
print(trace.graph)
for node in trace.graph.nodes():
if node.kind() == "prim::Constant":
continue
print(list(node.outputs())[0].type().scalarType())
print(type(trace))
k = torch.jit.load("b.pt")
print(type(k))
print(k.graph)
for node in k.graph.nodes():
if node.kind() == "prim::Constant":
continue
print(node)
output = list(node.outputs())[0]
print(output.type().scalarType())
the last
the last sentence produce the error below
RuntimeError: r ASSERT FAILED at /pytorch/aten/src/ATen/core/jit_type.h:142, please report a bug to PyTorch.