I have the following code:
import torch
import torch.fx as fx

class MyModule(torch.nn.Module):
    def __init__(self, param):
        super().__init__()
        # For simplicity, self.param is a single large tensor here.
        # In my real scenario it is a custom class that consumes 3-10 GB of CUDA memory.
        self.param = param

    def forward(self):
        one = self.param
        two = one + one
        three = two + one
        return one, two, three

param = torch.ones((1024, 1024, 768)).cuda()  # 3 GB of float32

f = MyModule(param)
torch.cuda.reset_peak_memory_stats()
print(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024)

gm = fx.symbolic_trace(f)
torch.cuda.reset_peak_memory_stats()
print(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024)

gm()
torch.cuda.reset_peak_memory_stats()
print(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024)
The output is:
3.0 GB
9.0 GB
9.0 GB
torch.fx treats self.param as a constant and caches the result of every calculation performed on it at trace time. Here that means two and three (each another 3 GB) stay alive on the GraphModule, which matches the jump from 3.0 GB to 9.0 GB after tracing and results in a very large memory overhead. Is there any way to stop this behavior?
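
For reference, this is how I checked where the memory goes. This is just an inspection sketch continuing from the script above; as far as I can tell, fx stores trace-time tensor results as attributes named _tensor_constant<N> on the GraphModule, but that naming is an implementation detail and may differ across PyTorch versions:

# Sketch: list every tensor the traced GraphModule keeps alive via
# get_attr nodes (continues from the script above, so gm is defined).
# In my build the cached trace-time results appear as attributes named
# "_tensor_constant<N>"; that is an implementation detail of fx.
for node in gm.graph.nodes:
    if node.op == "get_attr":
        t = getattr(gm, node.target)
        if isinstance(t, torch.Tensor):
            size_gb = t.numel() * t.element_size() / 1024 ** 3
            print(f"{node.target}: {tuple(t.shape)}, {size_gb:.2f} GB")

On my machine this prints the original param plus two extra 3 GB tensors, which accounts for the 9.0 GB reading.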