You can wrap your model in another class that converts its dict/list outputs into namedtuples/tuples, which graph visualization can handle.
The usage is straightforward:
# Example usage: `model`, `writer` (presumably a
# torch.utils.tensorboard.SummaryWriter) and `input_image` are assumed to be
# defined by the caller.
model_wrapper = ModelWrapper(model)
writer.add_graph(model_wrapper, input_image)
This class could probably help and will work as a temporary fix.
from collections import namedtuple
from functools import lru_cache
from typing import Any

import torch
# pylint: disable = abstract-method
class ModelWrapper(torch.nn.Module):
    """
    Wrapper class for models whose forward returns dict/list outputs.

    Dict outputs are converted to ``ModelEndpoints`` namedtuples (fields in
    sorted key order) and lists to tuples, recursively, so the result can be
    passed to graph-visualization tools such as ``SummaryWriter.add_graph``.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Store the wrapped model.

        Args:
            model: The module whose outputs should be converted.
        """
        super().__init__()
        self.model = model

    def forward(self, input_x: torch.Tensor, *args: Any, **kwargs: Any) -> Any:
        """
        Run the wrapped model and convert its output to tuple form.

        Extra positional/keyword arguments are forwarded unchanged to the
        wrapped model, so models with several inputs can be wrapped too.

        Returns:
            The model output with dicts converted to ``ModelEndpoints``
            namedtuples and lists/plain tuples converted to tuples; any other
            value is returned as-is.
        """
        return _to_tuples(self.model(input_x, *args, **kwargs))


@lru_cache(maxsize=256)
def _endpoints_type(field_names: tuple) -> type:
    """
    Return a namedtuple class called ``ModelEndpoints`` with *field_names*.

    Cached so outputs with the same keys share one class instead of a fresh
    class being built on every forward call.
    """
    # rename=True replaces names namedtuple would reject (keywords, leading
    # underscores, non-identifiers like "out-1", stringified int keys,
    # duplicates) with positional names such as "_0" instead of raising.
    return namedtuple("ModelEndpoints", field_names, rename=True)  # type: ignore


def _to_tuples(value: Any) -> Any:
    """
    Recursively convert dicts to namedtuples and lists/plain tuples to tuples.

    Any other value (tensors, existing namedtuples, scalars, ...) is returned
    unchanged.
    """
    if isinstance(value, dict):
        keys = sorted(value)
        endpoints = _endpoints_type(tuple(keys))
        # Positional construction keeps working when rename=True changed a
        # field name (keyword construction with the raw keys would not).
        return endpoints(*(_to_tuples(value[key]) for key in keys))
    # Exact-type check on tuple so existing namedtuples keep their own type.
    if isinstance(value, list) or type(value) is tuple:  # pylint: disable=unidiomatic-typecheck
        return tuple(_to_tuples(item) for item in value)
    return value