torch.fx cannot trace the construction of torch.Size() objects

torch.fx cannot trace a model properly if its forward pass constructs or manipulates torch.Size() objects.
The example below tries to build a torch.Size object from the first and second dimensions of the input tensor.

import torch
import torch.nn as nn
import torch.fx as fx

class Example(nn.Module):
    """Minimal module that builds a ``torch.Size`` from its input's leading dims."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input):
        """Return ``torch.Size([input.size(0), input.size(1)])``.

        Works in eager mode. Under ``fx.symbolic_trace`` the ``size`` calls
        produce ``Proxy`` objects, which ``torch.Size`` rejects with a
        ``TypeError``.
        """
        batch_size, ch_size = input.size(0), input.size(1)
        return torch.Size([batch_size, ch_size])

if __name__ == "__main__":
    # Eager execution works and prints the reconstructed size.
    model = Example()
    print(model(torch.randn(16, 5, 3, 3)))

    # Symbolic tracing is the step that raises the TypeError.
    fx_model = fx.symbolic_trace(model)

    # Only reached if tracing succeeds: dump the graph and the generated code.
    fx_model.graph.print_tabular()
    print(fx_model.code)

The terminal output is:

torch.Size([16, 5])
Traceback (most recent call last):
  File "trace_bug.py", line 23, in <module>
    fx_model = fx.symbolic_trace(model)
  File "/home/shengcn/anaconda3/envs/torch110/lib/python3.7/site-packages/torch/fx/_symbolic_trace.py", line 907, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/shengcn/anaconda3/envs/torch110/lib/python3.7/site-packages/torch/fx/_symbolic_trace.py", line 615, in trace
    self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
  File "trace_bug.py", line 14, in forward
    return torch.Size([batch_size, ch_size])
TypeError: torch.Size() takes an iterable of 'int' (item 0 is 'Proxy')

I am new to FX, but fx.wrap may help here:

@fx.wrap
def get_size(x, y):
    """Build a two-element ``torch.Size`` from ``x`` and ``y``.

    Decorated with ``fx.wrap`` so that symbolic tracing records a call to
    this function rather than tracing into its body.
    """
    dims = [x, y]
    return torch.Size(dims)

class Example(nn.Module):
    """Variant of the module that delegates ``torch.Size`` creation to ``get_size``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input):
        """Return ``get_size`` applied to the sizes of the input's first two dims."""
        # Calling torch.Size([batch_size, ch_size]) directly here is what
        # fails under tracing, so the construction goes through the
        # fx.wrap-decorated helper instead.
        batch_size, ch_size = input.size(0), input.size(1)
        return get_size(batch_size, ch_size)

According to the documentation, fx.wrap registers the function as a "leaf function": a call to it is recorded in the graph as a single node instead of being traced through.