A Tensorboard problem about use add_graph method for deeplab-v3 in Torchvision.

You can wrap your model in another class that converts outputs from dict / list into namedtuple / tuple for graph visualization.

The usage is straightforward

model_wrapper = ModelWrapper(model)
writer.add_graph(model_wrapper, input_image)

This class this could probably help and will work as temporary fix.

from collections import namedtuple
from typing import Any

import torch


# pylint: disable = abstract-method
class ModelWrapper(torch.nn.Module):
    """
    Wrapper class for model with dict/list rvalues.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Init call.
        """
        super().__init__()
        self.model = model

    def forward(self, input_x: torch.Tensor) -> Any:
        """
        Wrap forward call.
        """
        data = self.model(input_x)

        if isinstance(data, dict):
            data_named_tuple = namedtuple("ModelEndpoints", sorted(data.keys()))  # type: ignore
            data = data_named_tuple(**data)  # type: ignore

        elif isinstance(data, list):
            data = tuple(data)

        return data

3 Likes