High memory usage for inference

I’m trying to load a JIT model in Python and run some inference, and the memory usage is higher than I expected. It seems that:

  • just importing torch adds about 80 MB of memory
  • loading a model that is 30 MB on disk adds about 110 MB of memory
  • the first few model calls add about 300 MB.
    I wonder if there’s a way to reduce that? Ideally, model calls would add no memory overhead at all.

I’m using Python’s memory_profiler to get these numbers, with torch v1.11:

Line #    Mem usage    Increment  Occurrences   Line Contents

 4     51.7 MiB     51.7 MiB           1   @profile
 5                                         def foo():
 6     62.9 MiB     11.2 MiB           1       import numpy as np
 7    146.3 MiB     83.4 MiB           1       import torch
 8    146.3 MiB      0.0 MiB           1       model_path = 'ckpt.pt'
 9    257.1 MiB    110.8 MiB           1       model = torch.jit.load(model_path)
10                                         
11    257.3 MiB      0.3 MiB           1       dummy_input = torch.tensor(np.random.default_rng().standard_normal([1, 150, 80]), dtype=torch.float32)
12    404.4 MiB    147.1 MiB           1       _ = model.encode(tokens=dummy_input)
13                                         
14    404.4 MiB      0.0 MiB           1       dummy_input = torch.tensor(np.random.default_rng().standard_normal([1, 150, 80]), dtype=torch.float32)
15    490.7 MiB     86.2 MiB           1       _ = model.encode(tokens=dummy_input)
16                                         
17    490.7 MiB      0.0 MiB           1       dummy_input = torch.tensor(np.random.default_rng().standard_normal([1, 150, 80]), dtype=torch.float32)
18    544.1 MiB     53.5 MiB           1       _ = model.encode(tokens=dummy_input)
19                                         
20    544.1 MiB      0.0 MiB           1       dummy_input = torch.tensor(np.random.default_rng().standard_normal([1, 150, 80]), dtype=torch.float32)
21    546.0 MiB      1.9 MiB           1       _ = model.encode(tokens=dummy_input)
22                                         
23    546.0 MiB      0.0 MiB           1       dummy_input = torch.tensor(np.random.default_rng().standard_normal([1, 150, 80]), dtype=torch.float32)
24    546.5 MiB      0.4 MiB           1       _ = model.encode(tokens=dummy_input)
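For reference, this is (roughly) the profiled script behind the output above, run with python -m memory_profiler script.py; the five encode() calls are collapsed into a loop here:

    from memory_profiler import profile  # supplies the @profile decorator

    @profile
    def foo():
        import numpy as np
        import torch

        model_path = 'ckpt.pt'
        model = torch.jit.load(model_path)

        # run encode() several times to watch how memory evolves across warmup passes
        for _ in range(5):
            dummy_input = torch.tensor(
                np.random.default_rng().standard_normal([1, 150, 80]),
                dtype=torch.float32,
            )
            _ = model.encode(tokens=dummy_input)

    if __name__ == '__main__':
        foo()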

If I understand the output correctly, the total memory in use is shown in the left column and the per-line delta in the column next to it?
If so, could you explain your concern a bit more, as it seems the memory usage stabilizes after the initial warmup? Assuming you are using TorchScript, note that the first few passes perform optimizations, so memory usage can change during that warmup. I also don’t know whether that memory is actually in use or just held in a cache, etc.
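One way to check is to watch the process RSS across repeated calls. A minimal sketch, assuming psutil is available and reusing the encode() call and input shape from your snippet:

    import gc
    import numpy as np
    import psutil
    import torch

    def rss_mib():
        # resident set size of the current process, in MiB
        return psutil.Process().memory_info().rss / 2**20

    model = torch.jit.load('ckpt.pt')

    with torch.inference_mode():  # keep autograd from recording anything
        for i in range(10):
            x = torch.tensor(
                np.random.default_rng().standard_normal([1, 150, 80]),
                dtype=torch.float32,
            )
            _ = model.encode(tokens=x)
            gc.collect()
            print(f'pass {i}: {rss_mib():.1f} MiB')

If the RSS flattens out after a few passes, the growth is most likely the JIT optimizing the graph during warmup rather than a per-call leak.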

You understand the columns correctly. My concern is that I can’t use the TorchScript model for inference on end devices with limited memory, such as a Raspberry Pi.

I wonder if there’s a way around this, or what the recommended way to do inference on such devices is. Ideally, model inference would be stateless for me, meaning that all memory used during the forward pass is freed once the forward pass completes.
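To make the “stateless” idea concrete, this is roughly the pattern I’m aiming for (just a sketch of intent; torch.set_num_threads is an extra knob I’m experimenting with, not something measured above):

    import gc
    import numpy as np
    import torch

    torch.set_num_threads(1)  # smaller intra-op thread pool; trades some speed for memory

    model = torch.jit.load('ckpt.pt').eval()

    def encode_once(features):
        # inference_mode() keeps autograd from recording anything, so the only
        # tensor that outlives the call is the returned output
        with torch.inference_mode():
            return model.encode(tokens=features)

    x = torch.tensor(
        np.random.default_rng().standard_normal([1, 150, 80]),
        dtype=torch.float32,
    )
    y = encode_once(x)
    del y
    gc.collect()  # the Python-side tensors are gone here, but the C allocator may
                  # still hold the pages, so RSS does not necessarily drop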

Raising this one again… Can anyone recommend how to do inference with a small memory footprint?