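`torch.fx` cannot trace a model properly when the model manipulates `torch.Size` objects.
The example below tries to reconstruct a `torch.Size` object from the input tensor's first and second dimensions. During symbolic tracing, `input.size(0)` and `input.size(1)` return `fx.Proxy` objects instead of `int`s. Because `torch.Size()` accepts only ints, tracing fails with the `TypeError` shown at the end.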
import torch
import torch.nn as nn
import torch.fx as fx
class Example(nn.Module):
    """Rebuild a ``torch.Size`` from the first two dimensions of the input.

    Called eagerly on a ``(16, 5, 3, 3)`` tensor, this returns
    ``torch.Size([16, 5])``. Under ``torch.fx.symbolic_trace`` each
    ``input.size(dim)`` call yields an ``fx.Proxy`` rather than an ``int``,
    and ``torch.Size`` rejects non-int items, so tracing raises ``TypeError``.
    The module is kept as-is because it is the reproducer for that failure.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input):
        # Query dim 0 then dim 1, in the same order as before, so the traced
        # graph (up to the failing call) records the same size() nodes.
        leading_dims = [input.size(dim) for dim in (0, 1)]
        # This is the call that fails under tracing: the items are Proxies.
        return torch.Size(leading_dims)
if __name__ == "__main__":
    module = Example()

    # Eager execution works and prints torch.Size([16, 5]).
    print(module(torch.randn(16, 5, 3, 3)))

    # Symbolic tracing is the step that fails with
    # "TypeError: torch.Size() takes an iterable of 'int' (item 0 is 'Proxy')".
    traced = fx.symbolic_trace(module)
    traced.graph.print_tabular()
    print(traced.code)
Terminal output:
torch.Size([16, 5])
Traceback (most recent call last):
File "trace_bug.py", line 23, in <module>
fx_model = fx.symbolic_trace(model)
File "/home/shengcn/anaconda3/envs/torch110/lib/python3.7/site-packages/torch/fx/_symbolic_trace.py", line 907, in symbolic_trace
graph = tracer.trace(root, concrete_args)
File "/home/shengcn/anaconda3/envs/torch110/lib/python3.7/site-packages/torch/fx/_symbolic_trace.py", line 615, in trace
self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
File "trace_bug.py", line 14, in forward
return torch.Size([batch_size, ch_size])
TypeError: torch.Size() takes an iterable of 'int' (item 0 is 'Proxy')