How to visualize/draw a model

@soulitzer

Thank you very much for sharing the reference. I followed it but couldn’t get it to work. It seems that Mask R-CNN does not return a single tensor: in eval mode it returns a list of dictionaries (one per image), which is why add_graph() fails. I found a potential workaround here, but I get the following error:

Error occurs, No graph saved
Traceback (most recent call last):
  File "maskrcnn_vis.py", line 28, in <module>
    writer.add_graph(model_wrapper, dummy)
  File "/home/ravi/anaconda/envs/torchenv/lib/python3.7/site-packages/torch/utils/tensorboard/writer.py", line 727, in add_graph
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose))
  File "/home/ravi/anaconda/envs/torchenv/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 292, in graph
    raise e
  File "/home/ravi/anaconda/envs/torchenv/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 286, in graph
    trace = torch.jit.trace(model, args)
...
...
...
RuntimeError: Encountering a dict at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.
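The error message names two ways out: return a constant container (a tuple instead of a list, a NamedTuple instead of a dict), or trace with strict=False. Below is a minimal sketch of the second option. It uses a hypothetical toy module (DictOutput, not part of the original code) and calls torch.jit.trace directly, since the traceback shows that add_graph calls it internally. As far as I can tell, add_graph in 1.9 does not forward a strict flag (newer releases add a use_strict_trace argument), so in 1.9 this only helps when tracing by hand.

```python
import torch


class DictOutput(torch.nn.Module):
    """Hypothetical toy module whose forward returns a dict (not from the original post)."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor):
        logits = self.linear(x)
        return {"logits": logits, "probs": logits.softmax(dim=-1)}


model = DictOutput().eval()
example = torch.rand(3, 4)

# Default (strict) tracing rejects a dict at the output, as in the traceback above.
try:
    torch.jit.trace(model, example)
    strict_failed = False
except RuntimeError:
    strict_failed = True

# strict=False records the dict; this is only safe if its keys never depend on the input.
traced = torch.jit.trace(model, example, strict=False)
out = traced(example)
print(strict_failed)       # True
print(sorted(out.keys()))  # ['logits', 'probs']
```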

Below is the version information:

Python 3.7.13 (default, Mar 29 2022, 02:18:16) 

In [1]: import torch
In [2]: import torchvision
In [3]: import tensorboard

In [4]: torch.__version__
Out[4]: '1.9.0+cu111'

In [5]: torchvision.__version__
Out[5]: '0.10.0+cu111'

In [6]: tensorboard.__version__
Out[6]: '2.11.2'

In [7]: torch.version.cuda
Out[7]: '11.1'

In [8]: torch.cuda.is_available()
Out[8]: True

The sample code is below:

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from collections import namedtuple
from typing import Any


class ModelWrapper(torch.nn.Module):
    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, input_x: torch.Tensor) -> Any:
        data = self.model(input_x)
        if isinstance(data, dict):
            data_named_tuple = namedtuple("ModelEndpoints", sorted(data.keys()))
            data = data_named_tuple(**data)
        elif isinstance(data, list):
            data = tuple(data)
        return data


model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model_wrapper = ModelWrapper(model)
dummy = torch.rand((4, 3, 64, 64))

writer = SummaryWriter("logs/maskrcnn_resnet50_fpn")
writer.add_graph(model_wrapper, dummy)
writer.close()
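As far as I can tell from torch/utils/tensorboard/_pytorch_graph.py, add_graph traces the model in eval mode. That means the wrapper above receives a list of per-image dicts, and converting the list to a tuple still leaves dicts inside it, which the strict tracer rejects. A minimal sketch of a wrapper that flattens every per-image dict into one flat tuple of tensors in a fixed key order is shown below. FlattenDetections and ToyDetector are hypothetical names. The keys boxes, labels, scores and masks are the per-image fields torchvision's Mask R-CNN returns in eval mode. ToyDetector only mimics that structure so the example runs without downloading weights.

```python
from typing import Dict, List, Sequence, Tuple

import torch

# Per-image keys returned by torchvision's Mask R-CNN in eval mode.
MASKRCNN_KEYS = ("boxes", "labels", "scores", "masks")


class FlattenDetections(torch.nn.Module):
    """Hypothetical wrapper: List[Dict[str, Tensor]] output -> flat tuple of tensors."""

    def __init__(self, model: torch.nn.Module, keys: Sequence[str] = MASKRCNN_KEYS) -> None:
        super().__init__()
        self.model = model
        # A fixed key order keeps the output structure independent of the input.
        self.keys = tuple(keys)

    def forward(self, images: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        detections: List[Dict[str, torch.Tensor]] = self.model(images)
        flat: List[torch.Tensor] = []
        for det in detections:  # one dict per image
            flat.extend(det[key] for key in self.keys)
        return tuple(flat)


class ToyDetector(torch.nn.Module):
    """Hypothetical stand-in that mimics Mask R-CNN's eval-mode output structure."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)

    def forward(self, images: torch.Tensor) -> List[Dict[str, torch.Tensor]]:
        results = []
        for img in images.unbind(0):
            mask = self.conv(img.unsqueeze(0)).sigmoid()  # (1, 1, H, W)
            score = mask.mean().reshape(1)                 # (1,)
            results.append({
                "boxes": score.unsqueeze(1).repeat(1, 4),  # (1, 4)
                "labels": torch.ones(1, dtype=torch.int64),
                "scores": score,
                "masks": mask,
            })
        return results


wrapper = FlattenDetections(ToyDetector()).eval()
images = torch.rand(2, 3, 16, 16)

# Strict tracing succeeds because the output is a tuple of tensors only.
traced = torch.jit.trace(wrapper, images)
outputs = traced(images)
print(len(outputs))  # 8 (2 images x 4 keys)
```

To apply it to the original script, swapping ModelWrapper(model) for FlattenDetections(model) and leaving the writer.add_graph(...) call unchanged should be enough. I have not run this against Mask R-CNN itself, though. Expect TracerWarnings, and note that the recorded graph is specialized to the dummy input (here a batch of 4 images).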