A TensorBoard problem when using the `add_graph` method with DeepLabV3 from torchvision.

I want to visualize the DeepLabV3 graph in TensorBoard, so I tried the code below, following the PyTorch documentation.
(https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html)

import torch
from torch import nn
from torchvision import models
from torch.utils.tensorboard import SummaryWriter

# Build DeepLabV3 with a ResNet-50 backbone and pretrained weights
# (`pretrained=True` is the argument used with the torchvision 0.7 listed below).
deeplabv3 = models.segmentation.deeplabv3_resnet50(pretrained=True)
# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter('runs/deeplabv3')
# Dummy input of shape (1, 3, 512, 512) with uniform values in [0, 1);
# presumably NCHW layout (1 image, 3 channels, 512x512) - confirm for other models.
x = torch.rand(1,3,512,512)
# add_graph traces the model with torch.jit.trace. This call raises the
# RuntimeError shown below because the model's forward returns a dict,
# which the tracer rejects in strict mode.
writer.add_graph(deeplabv3, x)

I get the following error.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~/Documents/4.自学自習/DeepLearning/SemanticSegmentation/torch_test.py in 
      293 breakpoint()
      294 
----> 295 writer.add_graph(deeplabv3, x)

~/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/utils/tensorboard/writer.py in add_graph(self, model, input_to_model, verbose)
    712         if hasattr(model, 'forward'):
    713             # A valid PyTorch model should have a 'forward' method
--> 714             self._get_file_writer().add_graph(graph(model, input_to_model, verbose))
    715         else:
    716             # Caffe2 models do not have the 'forward' method

~/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py in graph(model, args, verbose)
    289             print(e)
    290             print('Error occurs, No graph saved')
--> 291             raise e
    292 
    293     if verbose:

~/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py in graph(model, args, verbose)
    283     with torch.onnx.select_model_mode_for_export(model, torch.onnx.TrainingMode.EVAL):  # TODO: move outside of torch.onnx?
    284         try:
--> 285             trace = torch.jit.trace(model, args)
    286             graph = trace.graph
    287             torch._C._jit_pass_inline(graph)

~/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/jit/__init__.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    953         return trace_module(func, {'forward': example_inputs}, None,
    954                             check_trace, wrap_check_inputs(check_inputs),
--> 955                             check_tolerance, strict, _force_outplace, _module_class)
    956 
    957     if (hasattr(func, '__self__') and isinstance(func.__self__, torch.nn.Module) and

~/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/jit/__init__.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
   1107             func = mod if method_name == "forward" else getattr(mod, method_name)
   1108             example_inputs = make_tuple(example_inputs)
-> 1109             module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, strict, _force_outplace)
   1110             check_trace_method = module._c._get_method(method_name)
   1111 

RuntimeError: Encountering a dict at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.

When I tried the same code with vgg16, I could generate the model graph.
So I guess I need to adjust the code for torchvision's more complicated models, but I cannot figure out how.
I would appreciate any help.

~ versions ~

  • torch 1.6.0
  • torchvision 0.7.0
  • tensorboard 2.3.0

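`add_graph` traces the model with `torch.jit.trace`, which in strict mode rejects `dict` and `list` outputs (as the error above says, DeepLabV3's output is a dict). You can wrap your model in another class that converts those outputs from dict / list into namedtuple / tuple for graph visualization.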

The usage is straightforward:

# `model` and `input_image` are placeholders for your own module and example
# input (e.g. `deeplabv3` and `x` from the question).
# Wrapping converts dict/list outputs into namedtuple/tuple so tracing succeeds.
model_wrapper = ModelWrapper(model)
writer.add_graph(model_wrapper, input_image)

The following class should help and can serve as a temporary fix.

from collections import namedtuple
from typing import Any

import torch


# pylint: disable = abstract-method
class ModelWrapper(torch.nn.Module):
    """
    Wrap a model whose outputs contain dicts/lists so it can be traced.

    ``torch.jit.trace`` (used by ``SummaryWriter.add_graph``) rejects ``dict``
    and ``list`` outputs in strict mode. This wrapper converts them, at any
    nesting depth, into containers the tracer accepts:

    * ``dict`` -> ``namedtuple`` called ``ModelEndpoints`` whose fields are
      the keys in sorted order;
    * ``list`` -> ``tuple``;
    * ``tuple`` -> its items are converted; returned unchanged if nothing
      inside needed converting.

    Any other value (e.g. a tensor) is returned unchanged.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Store the wrapped model as a submodule.

        Args:
            model: module whose outputs should be made traceable.
        """
        super().__init__()
        self.model = model

    def forward(self, input_x: torch.Tensor, *args: Any) -> Any:
        """
        Run the wrapped model and convert its output to traceable containers.

        Args:
            input_x: first positional input for the wrapped model.
            *args: optional extra positional inputs, forwarded unchanged so
                models with several inputs can be wrapped as well.

        Returns:
            The model output with every dict replaced by a ``ModelEndpoints``
            namedtuple and every list by a tuple, recursively.
        """
        return self._to_traceable(self.model(input_x, *args))

    @staticmethod
    def _to_traceable(data: Any) -> Any:
        """Recursively replace dicts with namedtuples and lists with tuples."""
        convert = ModelWrapper._to_traceable

        if isinstance(data, dict):
            # Sorting by the string form gives the same order as plain
            # ``sorted`` for string keys, but does not raise on mixed key types.
            keys = sorted(data, key=str)
            # rename=True turns keys that are not valid identifiers
            # (e.g. "1st", "class", "_aux") into names like "_0" instead of
            # raising ValueError; values are therefore passed positionally.
            endpoints = namedtuple("ModelEndpoints", keys, rename=True)  # type: ignore
            return endpoints(*(convert(data[key]) for key in keys))

        if isinstance(data, list):
            return tuple(convert(item) for item in data)

        if isinstance(data, tuple):
            items = tuple(convert(item) for item in data)
            if all(new is old for new, old in zip(items, data)):
                # Nothing nested needed converting: keep the original object
                # (and its exact tuple subtype), as the unwrapped output was.
                return data
            if hasattr(data, "_fields"):
                # A namedtuple: rebuild it with the same type.
                return type(data)(*items)
            return items

        return data