Hi there. I'm trying to use torch.fx to trace the detection models in torchvision, but it looks like some of them cannot be traced at the moment. Is there a way to get the graph for models like SSD, or are there plans to address these issues in the future?
import torch
import torch.fx
import torchvision
from torch.fx import symbolic_trace
# Build an SSD300-VGG16 detector (default constructor arguments) and try to
# capture its forward pass as a torch.fx graph.
#
# NOTE: as the traceback below shows, symbolic tracing fails inside
# SSD.forward at `for target in targets:`. During symbolic tracing the
# inputs are torch.fx Proxy objects, and a Proxy cannot be iterated by a
# Python `for` loop, so FX raises TraceError.
ssd300_vgg16 = torchvision.models.detection.ssd300_vgg16()
fx_ssd300_vgg16: torch.fx.GraphModule = symbolic_trace(ssd300_vgg16)
# A bare expression only renders in an interactive session; print explicitly
# so the graph is also shown when this runs as a plain script.
print(fx_ssd300_vgg16.graph)
~/anaconda3/envs/base/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
---------------------------------------------------------------------------
TraceError Traceback (most recent call last)
<ipython-input-12-ca9b053ea500> in <module>
3 ssd300_vgg16 = torchvision.models.detection.ssd300_vgg16()
4
----> 5 fx_ssd300_vgg16 : torch.fx.GraphModule = symbolic_trace(ssd300_vgg16)
6 print(fx_ssd300_vgg16.graph)
~/anaconda3/envs/base/lib/python3.8/site-packages/torch/fx/symbolic_trace.py in symbolic_trace(root, concrete_args, enable_cpatching)
857 """
858 tracer = Tracer(enable_cpatching=enable_cpatching)
--> 859 graph = tracer.trace(root, concrete_args)
860 name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
861 return GraphModule(tracer.root, graph, name)
~/anaconda3/envs/base/lib/python3.8/site-packages/torch/fx/symbolic_trace.py in trace(self, root, concrete_args)
569 for module in self._autowrap_search:
570 _autowrap_check(patcher, module.__dict__, self._autowrap_function_ids)
--> 571 self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
572 type_expr=fn.__annotations__.get('return', None))
573
~/anaconda3/envs/base/lib/python3.8/site-packages/torchvision/models/detection/ssd.py in forward(self, images, targets)
289 if self.training:
290 assert targets is not None
--> 291 for target in targets:
292 boxes = target["boxes"]
293 if isinstance(boxes, torch.Tensor):
~/anaconda3/envs/base/lib/python3.8/site-packages/torch/fx/proxy.py in __iter__(self)
194 return (self[i] for i in range(inst.argval)) # type: ignore[index]
195
--> 196 return self.tracer.iter(self)
197
198 def __bool__(self) -> bool:
~/anaconda3/envs/base/lib/python3.8/site-packages/torch/fx/proxy.py in iter(self, obj)
135 information to the graph node using create_node and can choose to return an iterator.
136 """
--> 137 raise TraceError('Proxy object cannot be iterated. '
138 'This can be attempted when used in a for loop or as a *args or **kwargs function argument.')
139
TraceError: Proxy object cannot be iterated. This can be attempted when used in a for loop or as a *args or **kwargs function argument.