How to visualize/draw a model

Hi,

I have a model from torchvision, say Mask R-CNN, and I wish to visualize/draw it. For example, see the sample below:


Image Source: szagoruyko/pytorchviz

My model is initialized as shown below:

import torchvision
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

One way is to simply use print(model) to see the details. However, print loses the skip connections, branches, etc.
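To see why print(model) misses skip connections: print only walks the tree of registered submodules, while a skip connection is usually just a tensor addition inside forward(). The sketch below uses a hypothetical ResidualBlock and PyTorch's torch.fx.symbolic_trace (a different tool from the ones discussed in this thread) to show that the addition appears only in a traced graph. Symbolic tracing may not work directly on full detection models such as Mask R-CNN, which contain input-dependent Python control flow.

```python
import operator

import torch
import torch.fx


class ResidualBlock(torch.nn.Module):
    """Hypothetical tiny block with a skip connection: out = relu(conv(x)) + x."""

    def __init__(self, channels: int = 4) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The skip connection is this "+ x", not a registered submodule.
        return self.relu(self.conv(x)) + x


block = ResidualBlock()

# print() lists only the submodules (conv, relu); the addition is invisible.
print(block)

# A symbolic trace records the data flow, including the skip-connection add.
traced = torch.fx.symbolic_trace(block)
for node in traced.graph.nodes:
    print(node.op, node.target, node.args)

# Pick out the addition node(s) that implement the skip connection.
skip_adds = [n for n in traced.graph.nodes if n.target is operator.add]
print(skip_adds)
```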

So, how to visualize/draw a model?

TensorBoard can display PyTorch models; see the tutorial Visualizing Models, Data, and Training with TensorBoard (PyTorch Tutorials 2.0.0+cu117 documentation).

If you need a Python-only solution with convenient customization, see torchview, of which I am the author:

GitHub - mert-kurttutan/torchview: torchview: visualize pytorch models

It is as simple as

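from torchview import draw_graph
from torchvision.models import resnet18

model_graph = draw_graph(resnet18(), input_size=(1, 3, 32, 32), expand_nested=True)
model_graph.visual_graph  # graphviz object; displays inline in Jupyter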

The result:


@soulitzer

Thank you very much for sharing the reference. I followed it but couldn’t get it to work. It seems that Mask R-CNN does not return a single tensor but a dictionary, so add_graph() raises errors. I found a potential workaround here, but I get the following error:

Error occurs, No graph saved
Traceback (most recent call last):
  File "maskrcnn_vis.py", line 28, in <module>
    writer.add_graph(model_wrapper, dummy)
  File "/home/ravi/anaconda/envs/torchenv/lib/python3.7/site-packages/torch/utils/tensorboard/writer.py", line 727, in add_graph
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose))
  File "/home/ravi/anaconda/envs/torchenv/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 292, in graph
    raise e
  File "/home/ravi/anaconda/envs/torchenv/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 286, in graph
    trace = torch.jit.trace(model, args)
...
...
...
RuntimeError: Encountering a dict at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.

Below is the version information:

Python 3.7.13 (default, Mar 29 2022, 02:18:16) 

In [1]: import torch
In [2]: import torchvision
In [3]: import tensorboard

In [4]: torch.__version__
Out[4]: '1.9.0+cu111'

In [5]: torchvision.__version__
Out[5]: '0.10.0+cu111'

In [6]: tensorboard.__version__
Out[6]: '2.11.2'

In [7]: torch.version.cuda
Out[7]: '11.1'

In [8]: torch.cuda.is_available()
Out[8]: True

The sample code is below:

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from collections import namedtuple
from typing import Any


class ModelWrapper(torch.nn.Module):
    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, input_x: torch.Tensor) -> Any:
        data = self.model(input_x)
        if isinstance(data, dict):
            data_named_tuple = namedtuple("ModelEndpoints", sorted(data.keys()))
            data = data_named_tuple(**data)
        elif isinstance(data, list):
            data = tuple(data)
        return data


model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model_wrapper = ModelWrapper(model)
dummy = torch.rand((4, 3, 64, 64))

writer = SummaryWriter("logs/maskrcnn_resnet50_fpn")
writer.add_graph(model_wrapper, dummy)
writer.close()
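The tracer error asks for constant containers instead of dicts. Note that in eval mode torchvision detection models return a list with one dict per image, so converting only the outer list to a tuple (as ModelWrapper above does) still leaves dicts inside the output. A minimal sketch, assuming you only need a traceable graph and not the original output structure, is to flatten every output tensor into one tuple. flatten_tensors, FlatOutputWrapper, and ToyDetector are hypothetical names for illustration, and the toy model stands in for Mask R-CNN so the example runs without downloading weights.

```python
from typing import Any, List, Tuple

import torch


def flatten_tensors(obj: Any) -> List[torch.Tensor]:
    """Recursively collect every tensor from nested dicts, lists and tuples."""
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, dict):
        # Sort the keys so the flattened order is deterministic.
        return [t for key in sorted(obj) for t in flatten_tensors(obj[key])]
    if isinstance(obj, (list, tuple)):
        return [t for item in obj for t in flatten_tensors(item)]
    return []  # non-tensor leaves (ints, strings, None, ...) are dropped


class FlatOutputWrapper(torch.nn.Module):
    """Hypothetical wrapper: returns a flat tuple of tensors, an output type
    that torch.jit.trace accepts without strict=False."""

    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        return tuple(flatten_tensors(self.model(x)))


class ToyDetector(torch.nn.Module):
    """Hypothetical stand-in for a detector in eval mode: one dict per image."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> List[dict]:
        feats = self.conv(x)
        return [{"scores": f.mean(dim=(1, 2)), "masks": f.sigmoid()} for f in feats.unbind(0)]


# Trace the wrapped toy model; the output is a plain tuple of 4 tensors
# (2 images x 2 keys, ordered "masks" then "scores" for each image).
model = FlatOutputWrapper(ToyDetector()).eval()
dummy = torch.rand(2, 3, 8, 8)
traced = torch.jit.trace(model, dummy)
outputs = traced(dummy)
print([tuple(t.shape) for t in outputs])  # [(4, 8, 8), (4,), (4, 8, 8), (4,)]
```

For Mask R-CNN itself, the idea would be to pass FlatOutputWrapper(model) to writer.add_graph instead of ModelWrapper(model). This is untested here, and tracing may still run into other limitations inside the detection model.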

@ConvolutionalAtom

Thank you so much for introducing torchview. It looks straightforward. Unfortunately, I am not getting a graph as detailed as yours. Please see a screenshot below:

Yes, this is actually reported in one of the issues.

The thing is that this detection model uses an ImageList object when passing tensors between some of its submodules. torchview, however, only tracks tensors, mappable objects, and iterable objects. Since ImageList is neither mappable nor iterable, its tensor content is not recorded by torchview.
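The limitation can be illustrated with a simplified tensor collector that, like the behavior described above, only recurses into tensors, mappings, and iterables. This is a sketch, not torchview's actual code: collect_tensors is a hypothetical helper, and ImageListLike is a hypothetical stand-in for torchvision's ImageList, which is a plain object holding a batched tensors attribute and the original image_sizes.

```python
from collections.abc import Iterable, Mapping
from typing import Any, List

import torch


class ImageListLike:
    """Hypothetical stand-in for ImageList: plain attributes, no __iter__,
    __getitem__ or keys(), so it is neither a Mapping nor an Iterable."""

    def __init__(self, tensors: torch.Tensor, image_sizes: List[tuple]) -> None:
        self.tensors = tensors
        self.image_sizes = image_sizes


def collect_tensors(obj: Any) -> List[torch.Tensor]:
    """Collect tensors, looking only inside mappings and iterables."""
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, Mapping):
        return [t for v in obj.values() for t in collect_tensors(v)]
    # Skip str/bytes: iterating them yields more strings and never ends.
    if isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)):
        return [t for item in obj for t in collect_tensors(item)]
    return []  # anything else, including attributes of plain objects, is invisible


batch = torch.rand(2, 3, 32, 32)
print(len(collect_tensors({"images": batch})))  # 1: found inside a mapping
print(len(collect_tensors([batch])))  # 1: found inside an iterable
print(len(collect_tensors(ImageListLike(batch, [(32, 32), (32, 32)]))))  # 0: hidden in a plain object
```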

Currently working on this.

@ConvolutionalAtom

Thank you very much. I understand. I am looking forward to the fix.

Meanwhile, is there any workaround for now?

You can also use the Netron software to visualize your model and its tensors: GitHub - lutzroeder/netron: Visualizer for neural network, deep learning, and machine learning models

@Mohamed_Nabih

Thank you very much. I tried it well before posting my question.

This tool asks for a weight file (state dictionary), and I supplied it. It rendered a graphic, but none of the nodes were connected to each other; all I could see were isolated nodes.

It seems that visualizing Mask R-CNN is not supported by this tool.

I look forward to hearing your opinion on a possible workaround.

Thanks again.

Try TensorBoard:

pip install tensorboard

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
dummy_input = torch.rand(1, 3, 800, 800)

writer = SummaryWriter("runs/maskrcnn")
writer.add_graph(model, dummy_input)
writer.close()

Now, run TensorBoard in your terminal:

tensorboard --logdir runs

Open the provided URL in your browser to view the model in TensorBoard.

@AbdulsalamBande

Thank you very much. I tried it before, as mentioned here. However, at your suggestion I tried it again and got the following error:

  File "/home/ravi/tools/anaconda/envs/visenv/lib/python3.7/site-packages/torch/jit/_trace.py", line 959, in trace_module
    argument_names,
RuntimeError: Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions

Thank you very much

This is probably due to the same issue I mentioned regarding the use of the ImageList object.

@ConvolutionalAtom

Yes, you are right. The same cause produces errors everywhere.

Please let me know if there is any workaround or if your fix is implemented.

Thanks again

I recently released a package, TorchLens, for visualizing arbitrary PyTorch models. I verified that it works for maskrcnn_resnet50_fpn. The full visual is quite daunting (and too big to attach here), since TorchLens captures every operation in the model’s forward pass, but note that TorchLens also has options to visualize a model at different levels of nesting.