I have the same question as you. Did you find a solution for this?
According to the error log, I think the problem is the return type of the model: its output is a dict. However, add_graph needs a model that returns a tensor, or a list or tuple of tensors.
I tried wrapping the model so that its output is converted to a list or tuple,
# reference to https://giters.com/apache/tvm/issues/7971
def dict_to_tuple(out_dict):
    """Flatten a detection output dict into a tuple.

    Args:
        out_dict: Mapping that holds the entries "boxes", "scores" and
            "labels", and optionally "masks".

    Returns:
        ``(boxes, scores, labels)``, with ``masks`` appended as a fourth
        element when the "masks" key is present.

    Raises:
        KeyError: If "boxes", "scores" or "labels" is missing.
    """
    core = (out_dict["boxes"], out_dict["scores"], out_dict["labels"])
    # Membership is tested on the mapping itself; no need for .keys().
    if "masks" in out_dict:
        return core + (out_dict["masks"],)
    return core
class TraceWrapper(torch.nn.Module):
    """Wrap a model so that ``forward`` returns a tuple instead of a dict.

    The wrapped model is called as ``model(inp, targets)``; the first element
    of its result is passed to ``dict_to_tuple`` and that tuple is returned.
    """

    def __init__(self, model):
        """Store ``model`` as the wrapped sub-module."""
        super().__init__()
        self.model = model

    def forward(self, inp, targets=None):
        """Run the wrapped model and return its first output as a tuple.

        Args:
            inp: Input forwarded unchanged to the wrapped model.
            targets: Optional second argument forwarded to the wrapped model;
                defaults to ``None``.

        Returns:
            The tuple produced by ``dict_to_tuple`` for ``out[0]``.
        """
        out = self.model(inp, targets)
        # Only the first element of the model output is converted.
        return dict_to_tuple(out[0])
, but then I encountered another issue:
TracingCheckError Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_29156/3516386276.py in <module>
78 print("That's it!")
79
---> 80 main()
~\AppData\Local\Temp/ipykernel_29156/3516386276.py in main()
50 writer = SummaryWriter(logs_dir)
51 tracedModel = TraceWrapper(model)
---> 52 writer.add_graph(tracedModel, torch.randn(1, 3, 224, 224))
53 writer.close()
54
D:\Programs\Python\Python39\lib\site-packages\torch\utils\tensorboard\writer.py in add_graph(self, model, input_to_model, verbose)
725 if hasattr(model, 'forward'):
726 # A valid PyTorch model should have a 'forward' method
--> 727 self._get_file_writer().add_graph(graph(model, input_to_model, verbose))
728 else:
729 # Caffe2 models do not have the 'forward' method
D:\Programs\Python\Python39\lib\site-packages\torch\utils\tensorboard\_pytorch_graph.py in graph(model, args, verbose)
284 with torch.onnx.select_model_mode_for_export(model, torch.onnx.TrainingMode.EVAL): # TODO: move outside of torch.onnx?
285 try:
--> 286 trace = torch.jit.trace(model, args)
287 graph = trace.graph
288 torch._C._jit_pass_inline(graph)
D:\Programs\Python\Python39\lib\site-packages\torch\jit\_trace.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
733
734 if isinstance(func, torch.nn.Module):
--> 735 return trace_module(
736 func,
737 {"forward": example_inputs},
D:\Programs\Python\Python39\lib\site-packages\torch\jit\_trace.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
975 )
976 else:
--> 977 _check_trace(
978 [inputs],
979 func,
D:\Programs\Python\Python39\lib\site-packages\torch\autograd\grad_mode.py in decorate_context(*args, **kwargs)
26 def decorate_context(*args, **kwargs):
27 with self.__class__():
---> 28 return func(*args, **kwargs)
29 return cast(F, decorate_context)
30
D:\Programs\Python\Python39\lib\site-packages\torch\jit\_trace.py in _check_trace(check_inputs, func, traced_func, check_tolerance, strict, force_outplace, is_trace_module, _module_class)
519 diag_info = graph_diagnostic_info()
520 if any(info is not None for info in diag_info):
--> 521 raise TracingCheckError(*diag_info)
522
First diverging operator:
TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
Graph diff:
graph(%self.1 : __torch__.TraceWrapper,
%input.1 : Tensor):
...
Node diff:
- %2 : __torch__.torchvision.models.detection.mask_rcnn.___torch_mangle_3257.MaskRCNN = prim::GetAttr[name="model"](%self.1)
? ^ ^
+ %2 : __torch__.torchvision.models.detection.mask_rcnn.___torch_mangle_3450.MaskRCNN = prim::GetAttr[name="model"](%self.1)
?
I think that, because I added a wrapper, getting the attribute named “model” may be what causes the error.