I’m trying to load a JIT model in Python and run some inference, and the memory usage is higher than I expected. It seems that:
- just importing torch adds ~80 MB of memory
- loading a model that is 30 MB on disk adds ~110 MB of memory
- the first few model calls add about 300 MB more
Is there a way to reduce this? Ideally, for example, model calls would add no memory overhead at all.
I’m using Python’s memory_profiler to get these numbers, on torch v1.11 (the full script is reproduced after the report below):
```
Line #    Mem usage    Increment Occurences   Line Contents
============================================================
     4     51.7 MiB     51.7 MiB          1   @profile
     5                                        def foo():
     6     62.9 MiB     11.2 MiB          1       import numpy as np
     7    146.3 MiB     83.4 MiB          1       import torch
     8    146.3 MiB      0.0 MiB          1       model_path = 'ckpt.pt'
     9    257.1 MiB    110.8 MiB          1       model = torch.jit.load(model_path)
    10
    11    257.3 MiB      0.3 MiB          1       dummy_input = torch.tensor(np.random.default_rng().standard_normal([1, 150, 80]), dtype=torch.float32)
    12    404.4 MiB    147.1 MiB          1       _ = model.encode(tokens=dummy_input)
    13
    14    404.4 MiB      0.0 MiB          1       dummy_input = torch.tensor(np.random.default_rng().standard_normal([1, 150, 80]), dtype=torch.float32)
    15    490.7 MiB     86.2 MiB          1       _ = model.encode(tokens=dummy_input)
    16
    17    490.7 MiB      0.0 MiB          1       dummy_input = torch.tensor(np.random.default_rng().standard_normal([1, 150, 80]), dtype=torch.float32)
    18    544.1 MiB     53.5 MiB          1       _ = model.encode(tokens=dummy_input)
    19
    20    544.1 MiB      0.0 MiB          1       dummy_input = torch.tensor(np.random.default_rng().standard_normal([1, 150, 80]), dtype=torch.float32)
    21    546.0 MiB      1.9 MiB          1       _ = model.encode(tokens=dummy_input)
    22
    23    546.0 MiB      0.0 MiB          1       dummy_input = torch.tensor(np.random.default_rng().standard_normal([1, 150, 80]), dtype=torch.float32)
    24    546.5 MiB      0.4 MiB          1       _ = model.encode(tokens=dummy_input)
```
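
For reference, here’s the profiled script, reconstructed from the Line Contents column above. The body of `foo()` is verbatim; the `memory_profiler` import and the call at the bottom are the standard harness, so the exact line numbers in my file differ slightly from this sketch:

```python
# Reconstruction of the profiled script; run with `python script.py`
# (or `python -m memory_profiler script.py`, which injects @profile itself).
from memory_profiler import profile


@profile
def foo():
    import numpy as np
    import torch

    model_path = 'ckpt.pt'  # the checkpoint is ~30 MB on disk
    model = torch.jit.load(model_path)

    # Five identical calls with fresh random float32 input of shape
    # [1, 150, 80]; unrolled so each call gets its own line in the report.
    dummy_input = torch.tensor(np.random.default_rng().standard_normal([1, 150, 80]), dtype=torch.float32)
    _ = model.encode(tokens=dummy_input)

    dummy_input = torch.tensor(np.random.default_rng().standard_normal([1, 150, 80]), dtype=torch.float32)
    _ = model.encode(tokens=dummy_input)

    dummy_input = torch.tensor(np.random.default_rng().standard_normal([1, 150, 80]), dtype=torch.float32)
    _ = model.encode(tokens=dummy_input)

    dummy_input = torch.tensor(np.random.default_rng().standard_normal([1, 150, 80]), dtype=torch.float32)
    _ = model.encode(tokens=dummy_input)

    dummy_input = torch.tensor(np.random.default_rng().standard_normal([1, 150, 80]), dtype=torch.float32)
    _ = model.encode(tokens=dummy_input)


if __name__ == '__main__':
    foo()
```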