How can I successfully symbolically trace detection models into an FX graph?

Hi there. I am trying to use torch.fx to trace the detection models in torchvision, but it looks like some models cannot be traced for now. Is there a way to get the graph for models like SSD, or are there plans to address these issues in the future?

import torch
import torchvision
from torch.fx import symbolic_trace
# Build the SSD300-VGG16 detector with its default constructor arguments.
ssd300_vgg16 = torchvision.models.detection.ssd300_vgg16()

# A freshly constructed nn.Module is in training mode. In that mode
# SSD.forward enters its `if self.training:` branch and loops over `targets`,
# which is an fx Proxy during symbolic tracing and cannot be iterated -- this
# is exactly the TraceError shown in the traceback. Trace the inference path.
ssd300_vgg16.eval()

# NOTE(review): switching to eval mode only removes the training-only loop.
# Other parts of the detection forward (e.g. Python loops over the input image
# list or data-dependent post-processing) may still be untraceable with plain
# symbolic_trace; verify against the installed torchvision version and, if
# needed, trace a sub-module (such as the backbone) instead.
fx_ssd300_vgg16: torch.fx.GraphModule = symbolic_trace(ssd300_vgg16)

# A bare expression only displays in a notebook; print so it shows when run
# as a script as well.
print(fx_ssd300_vgg16.graph)
~/anaconda3/envs/base/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
---------------------------------------------------------------------------
TraceError                                Traceback (most recent call last)
<ipython-input-12-ca9b053ea500> in <module>
      3 ssd300_vgg16 = torchvision.models.detection.ssd300_vgg16()
      4 
----> 5 fx_ssd300_vgg16 : torch.fx.GraphModule = symbolic_trace(ssd300_vgg16)
      6 print(fx_ssd300_vgg16.graph)

~/anaconda3/envs/base/lib/python3.8/site-packages/torch/fx/symbolic_trace.py in symbolic_trace(root, concrete_args, enable_cpatching)
    857     """
    858     tracer = Tracer(enable_cpatching=enable_cpatching)
--> 859     graph = tracer.trace(root, concrete_args)
    860     name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    861     return GraphModule(tracer.root, graph, name)

~/anaconda3/envs/base/lib/python3.8/site-packages/torch/fx/symbolic_trace.py in trace(self, root, concrete_args)
    569                 for module in self._autowrap_search:
    570                     _autowrap_check(patcher, module.__dict__, self._autowrap_function_ids)
--> 571                 self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
    572                                  type_expr=fn.__annotations__.get('return', None))
    573 

~/anaconda3/envs/base/lib/python3.8/site-packages/torchvision/models/detection/ssd.py in forward(self, images, targets)
    289         if self.training:
    290             assert targets is not None
--> 291             for target in targets:
    292                 boxes = target["boxes"]
    293                 if isinstance(boxes, torch.Tensor):

~/anaconda3/envs/base/lib/python3.8/site-packages/torch/fx/proxy.py in __iter__(self)
    194             return (self[i] for i in range(inst.argval))  # type: ignore[index]
    195 
--> 196         return self.tracer.iter(self)
    197 
    198     def __bool__(self) -> bool:

~/anaconda3/envs/base/lib/python3.8/site-packages/torch/fx/proxy.py in iter(self, obj)
    135         information to the graph node using create_node and can choose to return an iterator.
    136         """
--> 137         raise TraceError('Proxy object cannot be iterated. '
    138                          'This can be attempted when used in a for loop or as a *args or **kwargs function argument.')
    139 

TraceError: Proxy object cannot be iterated. This can be attempted when used in a for loop or as a *args or **kwargs function argument.
1 Like