Constants consume too much GPU memory in torch.fx

I have the following code:

import torch
import torch.fx as fx

class MyModule(torch.nn.Module):

    def __init__(self, param):
        super().__init__()
        # To keep the example simple, self.param is just a large tensor.
        # In my real scenario it is a custom class that holds 3-10 GB of CUDA memory.
        self.param = param

    def forward(self):
        one = self.param
        two = one + one
        three = two + one
        return one, two, three

param = torch.ones((1024, 1024, 768)).cuda()  # ~3 GB of float32
f = MyModule(param)

torch.cuda.reset_peak_memory_stats()
print(torch.cuda.max_memory_reserved() / 1024 ** 3, "GB")  # after allocating param

gm = fx.symbolic_trace(f)

torch.cuda.reset_peak_memory_stats()
print(torch.cuda.max_memory_reserved() / 1024 ** 3, "GB")  # after tracing

gm()

torch.cuda.reset_peak_memory_stats()
print(torch.cuda.max_memory_reserved() / 1024 ** 3, "GB")  # after running the GraphModule

The output is:

3.0 GB
9.0 GB
9.0 GB

Because self.param is a plain attribute (not a Parameter or buffer), torch.fx hands forward() the real tensor rather than a Proxy during tracing. The additions therefore run eagerly on the GPU, and their results (two and three) are stored on the traced GraphModule as tensor constants, which is where the extra 6 GB comes from. Is there any way to stop this behavior?

In short, how do I trace a "global" tensor like this in torch.fx without it being folded into constants?
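
The closest workaround I have found, and it only applies when the attribute really is a tensor (so not my real custom class), is to register it as a frozen nn.Parameter. Recent torch.fx versions proxy parameter attribute accesses (through Tracer.getattr) into get_attr nodes, so the additions stay symbolic instead of being folded into constants. A minimal sketch, assuming that parameter-proxying behavior:

import torch
import torch.fx as fx

class MyModule(torch.nn.Module):
    def __init__(self, param):
        super().__init__()
        # Registering the tensor as a frozen Parameter makes the tracer
        # return a get_attr Proxy for self.param instead of the raw tensor.
        self.param = torch.nn.Parameter(param, requires_grad=False)

    def forward(self):
        one = self.param
        two = one + one       # stays symbolic: recorded as a graph node
        three = two + one
        return one, two, three

param = torch.ones((1024, 1024, 768)).cuda()
gm = fx.symbolic_trace(MyModule(param))
print(gm.graph)  # should show get_attr + add nodes, no _tensor_constant* attributes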

It seems that the only general solution is to pass param as a forward argument rather than keeping it as a module attribute, as sketched below.
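
For completeness, here is a minimal sketch of that approach: param becomes a placeholder node, so every operation on it is recorded in the graph and only executed when the GraphModule is called.

import torch
import torch.fx as fx

class MyModule(torch.nn.Module):
    # param is now a forward argument, so symbolic_trace sees a Proxy
    # (a placeholder node) instead of a real tensor.
    def forward(self, param):
        one = param
        two = one + one       # recorded as a call_function node, not executed
        three = two + one
        return one, two, three

gm = fx.symbolic_trace(MyModule())
print(gm.graph)               # placeholder -> add -> add, no _tensor_constant* attributes

param = torch.ones((1024, 1024, 768)).cuda()
one, two, three = gm(param)   # the computation happens here, at call time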