Transfer Learning: Adding graph to tensorboard, with the input

I was wondering if there is a quick way to see what input a model expects. Surely the model knows what input it needs, yet there are no tf.placeholders to inspect.

Specifically, I am working through this tutorial:

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

def get_instance_segmentation_model(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model

I cannot add this model to TensorBoard to view its graph, even when I add the SummaryWriter in the training loop.
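Because PyTorch modules execute eagerly, there is nothing like a `tf.placeholder` stored on the model. A minimal sketch of the closest substitute, assuming argument names and type hints are enough: call `inspect.signature` on the module's bound `forward` method. `ToyDetector` is a hypothetical stand-in whose `forward(images, targets=None)` follows the same pattern as torchvision's detection models; tensor shapes are not recorded anywhere, so those still have to come from the model's documentation.

```python
import inspect
from typing import Dict, List, Optional

import torch
from torch import Tensor


class ToyDetector(torch.nn.Module):
    # Hypothetical stand-in: same forward(images, targets=None) pattern as
    # torchvision's detection models, but with a trivial body.
    def forward(
        self,
        images: List[Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> List[Dict[str, Tensor]]:
        return [{"mean": img.mean()} for img in images]


def describe_inputs(model: torch.nn.Module) -> inspect.Signature:
    # model.forward is a bound method, so `self` is already excluded.
    # Only names, defaults and annotations are available; no tensor shapes.
    return inspect.signature(model.forward)


sig = describe_inputs(ToyDetector())
print(sig)
for name, param in sig.parameters.items():
    print(name, param.annotation, param.default)
```

The same `inspect.signature(model.forward)` call works on any `nn.Module`, including the torchvision models above.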

For example, I don’t understand why this doesn’t work. Surely the model and its input are being added.

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
# PennFudanDataset, get_transform and utils (collate_fn) come from the tutorial

writer = SummaryWriter()

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=4, collate_fn=utils.collate_fn)
# For training
images, targets = next(iter(data_loader))
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]
#writer.add_graph(model, (images, targets))
output = model(images, targets)   # In train mode, returns a dict of losses
# For inference
model.eval()   # without this, calling the model without targets raises an error
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
writer.add_graph(model, x)

predictions = model(x)
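A separate pitfall may be hiding in `writer.add_graph(model, x)`: `add_graph` passes its input to `torch.jit.trace`, and `torch.jit.trace` appears to convert a list of example inputs into a tuple of separate positional arguments. For a model whose `forward` takes one list of images, the list would then need wrapping as `(x,)`. The sketch below checks this on a hypothetical toy module (`PerImageMean`) with `torch.jit.trace` directly; it does not touch the dict-output problem.

```python
from typing import List, Tuple

import torch
from torch import Tensor


class PerImageMean(torch.nn.Module):
    # Hypothetical toy model: like a detector, forward takes ONE argument,
    # a list of differently sized CHW images.
    def forward(self, images: List[Tensor]) -> Tuple[Tensor, ...]:
        # Return a tuple (not a list/dict) so strict tracing accepts it.
        return tuple(img.mean(dim=(1, 2)) for img in images)


x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
model = PerImageMean().eval()

# Passing the list directly: it is unpacked into forward(x[0], x[1]),
# which does not match forward(images), so tracing fails.
try:
    torch.jit.trace(model, x)
    list_failed = False
except Exception:
    list_failed = True

# Wrapping the list in a one-element tuple passes it as the single `images` argument.
traced = torch.jit.trace(model, (x,))
out = traced(x)
print(list_failed, [o.shape for o in out])
```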

I have the same question. Did you find a solution for this?
Judging from the error log, I think the problem is the model's return type: the output is a dict, whereas add_graph needs a model that returns a tensor or a list or tuple of tensors.
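The diagnosis can be reproduced on a small scale. Assuming the default strict tracing used by `add_graph` (it calls `torch.jit.trace`), a model that returns a dict is rejected, while a wrapper returning the same tensors as a tuple traces cleanly. `DictHead` and `TupleWrapper` are hypothetical names for illustration, not torchvision classes.

```python
from typing import Dict, Tuple

import torch
from torch import Tensor


class DictHead(torch.nn.Module):
    # Hypothetical toy model that, like the detection models, returns a dict.
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        logits = self.fc(x)
        return {"scores": logits.softmax(dim=-1), "labels": logits.argmax(dim=-1)}


class TupleWrapper(torch.nn.Module):
    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()  # required before assigning submodules
        self.model = model

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        out = self.model(x)
        # Fixed key order, so the tuple structure never depends on the input.
        return out["scores"], out["labels"]


model = DictHead().eval()
example = torch.randn(3, 4)

# Strict tracing (the default) rejects the dict output.
try:
    torch.jit.trace(model, example)
    dict_rejected = False
except Exception:
    dict_rejected = True

# The tuple-returning wrapper traces without complaint.
traced = torch.jit.trace(TupleWrapper(model), example)
scores, labels = traced(example)
print(dict_rejected, scores.shape, labels.shape)
```

`torch.jit.trace` also accepts `strict=False`, which records mutable containers instead of rejecting them; whether TensorBoard's graph view then handles the detection model's outputs is untested here.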
I tried wrapping the model so that it returns a tuple instead of a dict:

import torch

# reference to
def dict_to_tuple(out_dict):
    # Return the detection tensors as a tuple, which add_graph can handle
    if "masks" in out_dict:
        return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
    return out_dict["boxes"], out_dict["scores"], out_dict["labels"]

class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()  # required before assigning submodules
        self.model = model

    def forward(self, inp, targets=None):
        out = self.model(inp, targets)
        # In eval mode the model returns one dict per image; keep the first
        return dict_to_tuple(out[0])

However, I then ran into another issue:

TracingCheckError                         Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_29156/ in <module>
     78     print("That's it!")
---> 80 main()

~\AppData\Local\Temp/ipykernel_29156/ in main()
     50     writer = SummaryWriter(logs_dir)
     51     tracedModel = TraceWrapper(model)
---> 52     writer.add_graph(tracedModel, torch.randn(1, 3, 224, 224))
     53     writer.close()

D:\Programs\Python\Python39\lib\site-packages\torch\utils\tensorboard\ in add_graph(self, model, input_to_model, verbose)
    725         if hasattr(model, 'forward'):
    726             # A valid PyTorch model should have a 'forward' method
--> 727             self._get_file_writer().add_graph(graph(model, input_to_model, verbose))
    728         else:
    729             # Caffe2 models do not have the 'forward' method

D:\Programs\Python\Python39\lib\site-packages\torch\utils\tensorboard\ in graph(model, args, verbose)
    284     with torch.onnx.select_model_mode_for_export(model, torch.onnx.TrainingMode.EVAL):  # TODO: move outside of torch.onnx?
    285         try:
--> 286             trace = torch.jit.trace(model, args)
    287             graph = trace.graph
    288             torch._C._jit_pass_inline(graph)

D:\Programs\Python\Python39\lib\site-packages\torch\jit\ in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    734     if isinstance(func, torch.nn.Module):
--> 735         return trace_module(
    736             func,
    737             {"forward": example_inputs},

D:\Programs\Python\Python39\lib\site-packages\torch\jit\ in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    975                     )
    976                 else:
--> 977                     _check_trace(
    978                         [inputs],
    979                         func,

D:\Programs\Python\Python39\lib\site-packages\torch\autograd\ in decorate_context(*args, **kwargs)
     26         def decorate_context(*args, **kwargs):
     27             with self.__class__():
---> 28                 return func(*args, **kwargs)
     29         return cast(F, decorate_context)

D:\Programs\Python\Python39\lib\site-packages\torch\jit\ in _check_trace(check_inputs, func, traced_func, check_tolerance, strict, force_outplace, is_trace_module, _module_class)
    519         diag_info = graph_diagnostic_info()
    520         if any(info is not None for info in diag_info):
--> 521             raise TracingCheckError(*diag_info)
TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
	Graph diff:
		  graph(%self.1 : __torch__.TraceWrapper,
		        %input.1 : Tensor):
	First diverging operator:
	Node diff:
		- %2 : __torch__.torchvision.models.detection.mask_rcnn.___torch_mangle_3257.MaskRCNN = prim::GetAttr[name="model"](%self.1)
		?                                                                        ^ ^
		+ %2 : __torch__.torchvision.models.detection.mask_rcnn.___torch_mangle_3450.MaskRCNN = prim::GetAttr[name="model"](%self.1)

I think that because I added a wrapper, looking up the attribute named “model” may be what goes wrong.

I have the same issue. Were you able to resolve it? I used the TraceWrapper as well and also got diverging nodes.
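One possible, untested workaround: the node diff above shows only the auto-generated class name changing (`___torch_mangle_3257` vs `___torch_mangle_3450`), which suggests the sanity check itself is what trips. `torch.jit.trace` has a `check_trace` parameter that skips this check, so the sketch below traces up front with `check_trace=False`. It assumes that `torch.jit.trace`, when handed something that is already a `ScriptModule`, returns it unchanged, so that passing the pre-traced module to `add_graph` would reuse that trace. `TinyNet` is a hypothetical stand-in for the wrapped detection model, and none of this has been verified against Mask R-CNN.

```python
import warnings
from typing import Tuple

import torch
from torch import Tensor


class TinyNet(torch.nn.Module):
    # Hypothetical stand-in for the wrapped detection model.
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        feats = self.conv(x)
        return feats.mean(dim=(2, 3)), feats.amax(dim=(2, 3))


example = torch.randn(1, 3, 224, 224)

# check_trace=False skips the "Graphs differed across invocations" sanity check.
# Only do this once you are confident the reported differences are spurious.
traced = torch.jit.trace(TinyNet().eval(), example, check_trace=False)

# Assumption: tracing something that is already a ScriptModule is a no-op,
# so add_graph (which calls torch.jit.trace internally) should reuse this trace.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    retraced = torch.jit.trace(traced, example)

print(isinstance(traced, torch.jit.ScriptModule), retraced is traced)

# Hypothetical next step, with tensorboard installed (untested here):
# from torch.utils.tensorboard import SummaryWriter
# with SummaryWriter() as writer:
#     writer.add_graph(traced, example)
```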