TensorBoard `add_graph`: visualizing a model partially composed of `torch.distributions`

I’m experimenting with a model that has a Gaussian mixture model (GMM) at its output, implemented with `torch.distributions`. Can this model be visualized in TensorBoard?

When I call `writer.add_graph` on the model, I get the following error:

```
/opt/conda/lib/python3.7/site-packages/torch/utils/tensorboard/writer.py in add_graph(self, model, input_to_model, verbose)
    705         if hasattr(model, 'forward'):
    706             # A valid PyTorch model should have a 'forward' method
--> 707             self._get_file_writer().add_graph(graph(model, input_to_model, verbose))
    708         else:
    709             # Caffe2 models do not have the 'forward' method

/opt/conda/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py in graph(model, args, verbose)
    289             print(e)
    290             print('Error occurs, No graph saved')
--> 291             raise e
    292 
    293     if verbose:

/opt/conda/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py in graph(model, args, verbose)
    283     with torch.onnx.set_training(model, False):  # TODO: move outside of torch.onnx?
    284         try:
--> 285             trace = torch.jit.trace(model, args)
    286             graph = trace.graph
    287             torch._C._jit_pass_inline(graph)

/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, _force_outplace, _module_class, _compilation_unit)
    873         return trace_module(func, {'forward': example_inputs}, None,
    874                             check_trace, wrap_check_inputs(check_inputs),
--> 875                             check_tolerance, _force_outplace, _module_class)
    876 
    877     if (hasattr(func, '__self__') and isinstance(func.__self__, torch.nn.Module) and

/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, _force_outplace, _module_class, _compilation_unit)
   1025             func = mod if method_name == "forward" else getattr(mod, method_name)
   1026             example_inputs = make_tuple(example_inputs)
-> 1027             module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, _force_outplace)
   1028             check_trace_method = module._c._get_method(method_name)
   1029 

RuntimeError: Tracer cannot infer type of (MixtureSameFamily(
  Categorical(<redacted>),
  MultivariateNormal(<redacted>)))
:Only tensors and (possibly nested) tuples of tensors, lists, or dictsare supported as inputs or outputs of traced functions, but instead got value of type MixtureSameFamily.
```

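The last line states the cause in no uncertain terms: the tracer only accepts tensors (or nested tuples, lists, or dicts of tensors) as inputs and outputs, and my model returns a `MixtureSameFamily` object. I am curious what workarounds there are.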
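One workaround that might be worth trying is to wrap the model in a thin module whose `forward` returns only tensors, for example the parameters of the mixture, and pass that wrapper to `add_graph` instead. The sketch below is a minimal illustration under stated assumptions: `GMMHead` is a hypothetical stand-in for the real model (its layer sizes and diagonal `scale_tril` are made up), and `TensorOutputWrapper` is a hypothetical helper, not part of PyTorch. Argument validation is turned off only to reduce tracer warnings.

```python
import torch
import torch.nn as nn
import torch.distributions as D


class GMMHead(nn.Module):
    """Hypothetical stand-in for the model in question: returns a MixtureSameFamily."""

    def __init__(self, in_dim=8, n_comp=3, out_dim=2):
        super().__init__()
        self.n_comp, self.out_dim = n_comp, out_dim
        self.logits = nn.Linear(in_dim, n_comp)
        self.loc = nn.Linear(in_dim, n_comp * out_dim)
        self.log_scale = nn.Linear(in_dim, n_comp * out_dim)

    def forward(self, x):
        loc = self.loc(x).view(-1, self.n_comp, self.out_dim)
        scale = self.log_scale(x).view(-1, self.n_comp, self.out_dim).exp()
        mix = D.Categorical(logits=self.logits(x), validate_args=False)
        # Diagonal lower-triangular factor, chosen here only for simplicity.
        comp = D.MultivariateNormal(
            loc, scale_tril=torch.diag_embed(scale), validate_args=False
        )
        return D.MixtureSameFamily(mix, comp, validate_args=False)


class TensorOutputWrapper(nn.Module):
    """Hypothetical helper: exposes the mixture parameters as a tuple of tensors."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        gmm = self.model(x)
        # The tracer accepts tuples of tensors, so return parameters
        # rather than the distribution object itself.
        return (
            gmm.mixture_distribution.logits,
            gmm.component_distribution.loc,
            gmm.component_distribution.scale_tril,
        )


model = GMMHead()
wrapped = TensorOutputWrapper(model)
x = torch.randn(4, 8)

# This is the call that add_graph makes internally (per the traceback above).
traced = torch.jit.trace(wrapped, x)
logits, loc, scale_tril = traced(x)
```

If tracing succeeds as above, something like `SummaryWriter().add_graph(wrapped, x)` should then work the same way, since `add_graph` calls `torch.jit.trace` on the module it is given. The graph will show the layers that produce the distribution parameters, not the distribution object itself, and training code can keep using the unwrapped `model`.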