PyTorch version: 2.0.1+cu117. After downgrading PyTorch to 1.8.2, the problem goes away.
I traced a model and ran inference in three ways:
# Trace the eager model. check_tolerance is the floating-point tolerance used by
# trace()'s built-in checker (check_trace is left at its default, True), so 1e-9
# makes that check very strict.
# NOTE(review): because check_trace runs the traced code during trace(), the call
# to traced_model below is not its first execution. Worth confirming whether the
# JIT executor behaves differently on later runs (e.g. call pt_model several
# times and compare each result).
traced_model = torch.jit.trace(traceable_model, example_input, check_tolerance=1e-9)

# Reference outputs from the eager PyTorch model.
expected_outputs = traceable_model(*example_input)

# Run the freshly traced module right away, in the same process.
traced_outputs = traced_model(*example_input)

# Serialize the traced module to disk.
dumped_model_path = args.output_path
logger.info("Dump TorchScript model to: {}.".format(dumped_model_path))
traced_model.save(dumped_model_path)
logger.info("loading from {}".format(dumped_model_path))

# Reload the serialized module and run inference with it.
pt_model = torch.jit.load(dumped_model_path)
pt_outputs = pt_model(*example_input)
Comparing the results:
def compare(expected_outputs, traced_outputs):
    """Compare two nested output structures element by element and log differences.

    The two arguments are iterated in lockstep. Each top-level item is itself
    iterated, and its elements are compared pairwise. Top-level items are
    labelled with ``names``; any item beyond those gets a positional label
    instead of being silently dropped.

    Args:
        expected_outputs: reference outputs (e.g. from the eager model).
        traced_outputs: outputs to check against the reference.

    Returns:
        The accumulated error: the sum of absolute differences for numeric
        tensors plus the number of mismatching elements for bool tensors.
        ``0`` if no tensors were compared.

    Raises:
        AssertionError: if a ``torch.Size`` element does not match its counterpart.
    """
    names = ('anchors', 'scores', 'attr_scores', 'box_delta')
    error = 0
    for index, (item1, item2) in enumerate(zip(expected_outputs, traced_outputs)):
        # Fall back to a positional label so extra outputs are still compared.
        name = names[index] if index < len(names) else 'output_{}'.format(index)
        print('name:', name)
        for o1, o2 in zip(item1, item2):
            if isinstance(o1, torch.Size):
                assert torch.eq(torch.as_tensor(o1), o2).all(), "{} differs {}.".format(o1, o2)
                continue
            if not isinstance(o1, torch.Tensor):
                logger.warning(
                    "unrecognized data type {} for data {}. "
                    "Skip compute tracing error on this data.".format(type(o1), o1)
                )
                continue
            if o1.numel() == 0:
                # torch.max raises on empty tensors; there is nothing to compare.
                continue
            if o1.dtype == torch.bool:
                diff = o1 ^ o2
                # torch.abs is not implemented for bool, so only report whether
                # any element flipped.
                logger.info(f'max diff: {diff.any().item()}')
            else:
                diff = torch.abs(o1 - o2)
                denom = torch.minimum(torch.abs(o1), torch.abs(o2))
                # Leave out zero denominators so the relative error is not inf/nan.
                nonzero = denom > 0
                if nonzero.any():
                    max_rel = torch.max(diff[nonzero] / denom[nonzero]).item()
                else:
                    max_rel = 0.0
                logger.info(f'max diff: {torch.max(diff).item()}, max rel diff: {max_rel}')
            error += diff.sum()
    if error > 0:
        logger.warning("Errors on example input is {}.".format(error))
    return error
# Compare each run against the eager reference, always passing the reference first.
logger.info('comparing traced outputs and expected outputs')
compare(expected_outputs, traced_outputs)  # reported: differs on 2.0.1
logger.info('comparing pt outputs and expected outputs')
compare(expected_outputs, pt_outputs)  # reported: no diff
As the comparisons above show, `expected_outputs` is the same as `pt_outputs`, but not equal to `traced_outputs`.