I’m experimenting with a model that uses a Gaussian mixture model (GMM) at the output, implemented using functionality from `torch.distributions`. Can this be visualized via TensorBoard?
When calling `writer.add_graph` on the model, I get the following error:
/opt/conda/lib/python3.7/site-packages/torch/utils/tensorboard/writer.py in add_graph(self, model, input_to_model, verbose)
705 if hasattr(model, 'forward'):
706 # A valid PyTorch model should have a 'forward' method
--> 707 self._get_file_writer().add_graph(graph(model, input_to_model, verbose))
708 else:
709 # Caffe2 models do not have the 'forward' method
/opt/conda/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py in graph(model, args, verbose)
289 print(e)
290 print('Error occurs, No graph saved')
--> 291 raise e
292
293 if verbose:
/opt/conda/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py in graph(model, args, verbose)
283 with torch.onnx.set_training(model, False): # TODO: move outside of torch.onnx?
284 try:
--> 285 trace = torch.jit.trace(model, args)
286 graph = trace.graph
287 torch._C._jit_pass_inline(graph)
/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, _force_outplace, _module_class, _compilation_unit)
873 return trace_module(func, {'forward': example_inputs}, None,
874 check_trace, wrap_check_inputs(check_inputs),
--> 875 check_tolerance, _force_outplace, _module_class)
876
877 if (hasattr(func, '__self__') and isinstance(func.__self__, torch.nn.Module) and
/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, _force_outplace, _module_class, _compilation_unit)
1025 func = mod if method_name == "forward" else getattr(mod, method_name)
1026 example_inputs = make_tuple(example_inputs)
-> 1027 module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, _force_outplace)
1028 check_trace_method = module._c._get_method(method_name)
1029
RuntimeError: Tracer cannot infer type of (MixtureSameFamily(
Categorical(<redacted>),
MultivariateNormal(<redacted>)))
:Only tensors and (possibly nested) tuples of tensors, lists, or dictsare supported as inputs or outputs of traced functions, but instead got value of type MixtureSameFamily.
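The last line states in no uncertain terms that the tracer used by `add_graph` (`torch.jit.trace`) only supports tensors, or (possibly nested) tuples, lists, or dicts of tensors, as inputs or outputs. A `MixtureSameFamily` distribution object therefore cannot be returned from a traced `forward`. I am curious what workarounds there are.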