Memory leaks at inference

I’m trying to serve my model with Flask, but I ran into high memory consumption that eventually shuts the server down.

I started profiling my app to find the spot with the huge memory allocation and traced it to model inference (if I comment out the network inference, there are no memory problems).
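For reference, the line-by-line tables below come from memory_profiler’s @profile decorator wrapped around my inference method, roughly like this (a simplified sketch; the Predictor class and its attributes stand in for my real class):

from memory_profiler import profile
import torch

class Predictor:
    def __init__(self, net):
        self.net = net              # the trained network (placeholder here)

    @profile                        # prints a line-by-line memory report on every call
    def predict(self, img):
        with torch.no_grad():
            return self.net(img)
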
First inference:

Line #    Mem usage    Increment   Line Contents
================================================
    49    261.6 MiB    261.6 MiB       @profile
    50                                 def predict(self, img):
    51    261.6 MiB      0.0 MiB           with torch.no_grad():
    52    269.3 MiB      7.7 MiB               data = self.test_image_transforms(img)
    53    269.3 MiB      0.0 MiB               data = torch.unsqueeze(data, dim=0)
    54    269.3 MiB      0.0 MiB               data = data.to(self.device)
    55    442.1 MiB    172.8 MiB               logit = self.net(data)
    56    442.1 MiB      0.0 MiB               pred = torch.sigmoid(logit.cpu())[0][0].data.numpy()
    57    442.1 MiB      0.0 MiB               mask = pred >= 0.5
    58
    59
    60    442.1 MiB      0.0 MiB           return mask

It can be seen that there is a huge memory allocation on line 55 (172.8 MiB), while total memory usage is 261.6 MiB before that allocation.

Second inference (after 10 sec):

Line #    Mem usage    Increment   Line Contents
================================================
    49    374.4 MiB    374.4 MiB       @profile
    50                                 def predict(self, img):
    51    374.4 MiB      0.0 MiB           with torch.no_grad():
    52    380.6 MiB      6.2 MiB               data = self.test_image_transforms(img)
    53    380.6 MiB      0.0 MiB               data = torch.unsqueeze(data, dim=0)
    54    380.6 MiB      0.0 MiB               data = data.to(self.device)
    55    548.5 MiB    167.9 MiB               logit = self.net(data)
    56    548.5 MiB      0.0 MiB               pred = torch.sigmoid(logit.cpu())[0][0].data.numpy()
    57    548.5 MiB      0.0 MiB               mask = pred >= 0.5
    58
    59
    60    548.5 MiB      0.0 MiB           return mask

Now there are 374.4 MiB allocated in total before the call, and the usage keeps growing like this on every subsequent inference.
Then I tried to manually free the unneeded memory by deleting the output (del logit) and calling the garbage collector, but it didn’t help at all.
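Roughly, the cleanup attempt looked like this (a sketch of what I mean by deleting the output and calling the garbage collector, inside the same predict method):

import gc

with torch.no_grad():
    logit = self.net(data)
    pred = torch.sigmoid(logit.cpu())[0][0].data.numpy()
    mask = pred >= 0.5
    del logit      # drop the reference to the network output
gc.collect()       # force a collection pass; memory usage stays at the same level anyway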

Then I went down into the forward method, where all the magic happens.

Here’s a profiler snapshot of the first inference:

 Line #    Mem usage    Increment   Line Contents
 ================================================
    116    269.3 MiB    269.3 MiB       @profile
    117                                 def forward(self, x):
    118    269.3 MiB      0.0 MiB           with torch.no_grad():
    119    269.3 MiB      0.0 MiB               h, w = x.size(2), x.size(3)
    120    317.6 MiB     48.3 MiB               f = self.base_network(x)
    121    317.6 MiB      0.0 MiB               p = self.psp(f)
    122    317.6 MiB      0.0 MiB               drop_1_out = self.drop_1(p)
    123    351.4 MiB     33.8 MiB               p = self.up_1(drop_1_out)
    124    351.4 MiB      0.0 MiB               p = self.drop_2(p)
    125
    126    364.3 MiB     12.9 MiB               p = self.up_2(p)
    127    364.3 MiB      0.0 MiB               p = self.drop_2(p)
    128
    129    396.3 MiB     32.0 MiB               p = self.up_3(p)
    130
    131    396.3 MiB      0.0 MiB               if (p.size(2) != h) or (p.size(3) != w):
    132    441.6 MiB     45.4 MiB                   p = F.interpolate(p, size=(h, w), mode='bilinear')
    133
    134    441.6 MiB      0.0 MiB               p = self.drop_2(p)
    135    487.0 MiB     45.4 MiB               r = self.final(p)
    136    487.0 MiB      0.0 MiB           return r

Here I also tried deleting the intermediate layer outputs, but that didn’t help either, except for deleting the last tensor p:

 Line #    Mem usage    Increment   Line Contents
 ================================================
    116    265.2 MiB    265.2 MiB       @profile
    117                                 def forward(self, x):
    118    265.2 MiB      0.0 MiB           with torch.no_grad():
    119    265.2 MiB      0.0 MiB               h, w = x.size(2), x.size(3)
    120    301.5 MiB     36.3 MiB               f = self.base_network(x)
    121    301.9 MiB      0.4 MiB               p = self.psp(f)
    122    301.9 MiB      0.0 MiB               drop_1_out = self.drop_1(p)
    123    320.8 MiB     19.0 MiB               p = self.up_1(drop_1_out)
    124    320.8 MiB      0.0 MiB               p = self.drop_2(p)
    125
    126    328.8 MiB      8.0 MiB               p = self.up_2(p)
    127    328.8 MiB      0.0 MiB               p = self.drop_2(p)
    128
    129    347.6 MiB     18.8 MiB               p = self.up_3(p)
    130
    131    347.6 MiB      0.0 MiB               if (p.size(2) != h) or (p.size(3) != w):
    132    373.7 MiB     26.0 MiB                   p = F.interpolate(p, size=(h, w), mode='bilinear')
    133
    134    373.7 MiB      0.0 MiB               p = self.drop_2(p)
    135    400.2 MiB     26.6 MiB               r = self.final(p)
    136
    137    374.6 MiB      0.0 MiB               del p
    138
    139
    140    374.6 MiB      0.0 MiB           return r

As you can see, deleting the tensor p released the 26.6 MiB allocated on the previous line.

But if I try to delete the other ones, something strange happens:

Line #    Mem usage    Increment   Line Contents
================================================
   116    264.9 MiB    264.9 MiB       @profile
   117                                 def forward(self, x):
   118    264.9 MiB      0.0 MiB           with torch.no_grad():
   119    264.9 MiB      0.0 MiB               h, w = x.size(2), x.size(3)
   120    305.0 MiB     40.2 MiB               f = self.base_network(x)
   121    305.0 MiB      0.0 MiB               p = self.psp(f)
   122    305.0 MiB      0.0 MiB               drop_1_out = self.drop_1(p)
   123    323.2 MiB     18.2 MiB               p = self.up_1(drop_1_out)
   124    323.2 MiB      0.0 MiB               p = self.drop_2(p)
   125
   126    331.7 MiB      8.5 MiB               p = self.up_2(p)
   127    331.7 MiB      0.0 MiB               p = self.drop_2(p)
   128
   129    352.3 MiB     20.6 MiB               up_3_out = self.up_3(p)
   130
   131    352.3 MiB      0.0 MiB               if (up_3_out.size(2) != h) or (up_3_out.size(3) != w):
   132    382.3 MiB     29.9 MiB                   up_3_out = F.interpolate(up_3_out, size=(h, w), mode='bilinear')
   133
   134    382.3 MiB      0.0 MiB               drop_2_out = self.drop_2(up_3_out)
   135    412.2 MiB     29.9 MiB               r = self.final(drop_2_out)
   136
   137    412.2 MiB      0.0 MiB               del p
   138    412.2 MiB      0.0 MiB               del up_3_out
   139    382.9 MiB      0.0 MiB               del drop_2_out
   140
   141
   142    382.9 MiB      0.0 MiB           return r

As you can see, only the last tensor is actually freed.
Does anybody have an idea how to release the allocated memory?


Is the memory growing in each iteration until you run out of memory?
If so, could you post the model definition, so that we could debug it?

Yes, in each iteration, until the server crashes due to lack of RAM.
The model architecture is the following:

import torch
import torch.nn.functional as F
from torch import nn
from torchvision.models import squeezenet1_1, resnet101
from torch.nn.init import xavier_normal_
from memory_profiler import profile


class SqueezeNetExtractor(nn.Module):
    def __init__(self):
        super(SqueezeNetExtractor, self).__init__()
        model = squeezenet1_1(pretrained=True)
        features = model.features
        self.feature1 = features[:2]
        self.feature2 = features[2:5]
        self.feature3 = features[5:8]
        self.feature4 = features[8:]

    def forward(self, x):
        f1 = self.feature1(x)
        f2 = self.feature2(f1)
        f3 = self.feature3(f2)
        f4 = self.feature4(f3)
        return f4


class PyramidPoolingModule(nn.Module):
    def __init__(self, in_channels, sizes=(1, 2, 3, 6)):
        super(PyramidPoolingModule, self).__init__()
        pyramid_levels = len(sizes)
        out_channels = in_channels // pyramid_levels

        pooling_layers = nn.ModuleList()
        for size in sizes:
            layers = [nn.AdaptiveAvgPool2d(size), nn.Conv2d(in_channels, out_channels, kernel_size=1)]
            pyramid_layer = nn.Sequential(*layers)
            pooling_layers.append(pyramid_layer)

        self.pooling_layers = pooling_layers

    def forward(self, x):
        h, w = x.size(2), x.size(3)
        features = [x]
        for pooling_layer in self.pooling_layers:
            # pool with different sizes
            pooled = pooling_layer(x)

            # upsample to original size
            upsampled = F.upsample(pooled, size=(h, w), mode='bilinear')

            features.append(upsampled)

        return torch.cat(features, dim=1)


class UpsampleLayer(nn.Module):
    def __init__(self, in_channels, out_channels, upsample_size=None):
        super().__init__()
        self.upsample_size = upsample_size

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        size = 2 * x.size(2), 2 * x.size(3)
        f = F.upsample(x, size=size, mode='bilinear')
        return self.conv(f)


class PSPNet(nn.Module):
    def __init__(self, num_class=1, sizes=(1, 2, 3, 6)):
        super(PSPNet, self).__init__()
        self.base_network = SqueezeNetExtractor()
        feature_dim = 512
        self.psp = PyramidPoolingModule(in_channels=feature_dim, sizes=sizes)
        self.drop_1 = nn.Dropout2d(p=0.3)

        self.up_1 = UpsampleLayer(2*feature_dim, 256)
        self.up_2 = UpsampleLayer(256, 64)
        self.up_3 = UpsampleLayer(64, 64)

        self.drop_2 = nn.Dropout2d(p=0.15)
        self.final = nn.Sequential(
            nn.Conv2d(64, num_class, kernel_size=1)
        )

        self._init_weight()
        
    # @profile
    def forward(self, x):
        h, w = x.size(2), x.size(3)
        f = self.base_network(x)
        p = self.psp(f)
        p = self.drop_1(p)
        p = self.up_1(p)
        p = self.drop_2(p)

        p = self.up_2(p)
        p = self.drop_2(p)

        p = self.up_3(p)

        if (p.size(2) != h) or (p.size(3) != w):
            p = F.interpolate(p, size=(h, w), mode='bilinear')

        p = self.drop_2(p)

        return self.final(p)

    def _init_weight(self):
        layers = [self.up_1, self.up_2, self.up_3, self.final]
        for layer in layers:
            # the entries above are container modules, so walk their sub-modules
            for m in layer.modules():
                if isinstance(m, nn.Conv2d):
                    xavier_normal_(m.weight.data)
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.normal_(1.0, 0.02)
                    m.bias.data.fill_(0)

I cannot reproduce this issue using your model definition and this code:

import os

import psutil
import torch

model = PSPNet()
x = torch.randn(2, 3, 224, 224)

process = psutil.Process(os.getpid())
for idx in range(100):
    print(idx, process.memory_full_info().rss / 1024**2)
    out = model(x)

The memory info shows a usage between roughly 202 and 513 MB over 120 iterations:

0 202.50390625
1 341.28515625
2 427.9609375
3 477.71484375
4 429.0390625
5 449.1484375
6 449.1484375
7 449.1484375
8 473.1796875
9 447.20703125
10 454.21484375
11 425.265625
12 438.8125
13 462.953125
14 438.78125
15 450.0078125
16 470.640625
17 495.65234375
18 495.65234375
19 466.71484375
20 490.94921875
21 425.29296875
22 449.8359375
23 474.328125
24 498.0078125
25 474.984375
26 427.91796875
27 498.296875
28 498.296875
29 435.6875
30 459.5703125
31 485.3515625
32 426.7265625
33 449.15234375
34 449.15234375
35 426.125
36 450.09765625
37 425.78125
38 451.04296875
39 474.50390625
40 498.73828125
41 436.09375
42 451.3046875
43 489.19921875
44 489.19921875
45 489.19921875
46 460.77734375
47 485.265625
48 426.62890625
49 499.5859375
50 499.5859375
51 430.8125
52 455.55859375
53 431.5078125
54 493.12109375
55 444.44921875
56 466.10546875
57 434.92578125
58 513.296875
59 464.46875
60 464.46875
61 479.6796875
62 479.6796875
63 479.6796875
64 430.80859375
65 454.78125
66 504.28125
67 406.48046875
68 479.4375
69 479.4375
70 479.4375
71 479.4375
72 430.640625
73 503.85546875
74 503.85546875
75 454.6015625
76 479.60546875
77 479.60546875
78 479.60546875
79 484.24609375
80 437.57421875
81 510.7890625
82 462.0859375
83 481.9375
84 505.13671875
85 443.75
86 468.23828125
87 468.23828125
88 468.23828125
89 443.578125
90 516.53515625
91 467.83203125
92 516.296875
93 467.4765625
94 467.4765625
95 479.59375
96 430.65234375
97 455.140625
98 455.140625
99 492.5234375
100 492.5234375
101 418.58203125
102 491.796875
103 491.796875
104 423.875
105 471.82421875
106 423.0234375
107 449.31640625
108 423.9765625
109 448.20703125
110 473.47265625
111 448.171875
112 454.1015625
113 502.82421875
114 454.14453125
115 454.14453125
116 428.69921875
117 491.859375
118 443.1875
119 497.58203125

Two questions about your snippet:

  1. Why does the usage grow to ~500 MB, even though the first inference took ~202 MB and every subsequent inference does the same work?
    Shouldn’t the allocated memory be deallocated at each new iteration?

  2. Why do I get lower and stable memory usage when I move the network run into a separate function?

model = PSPNet()
x = torch.randn(2, 3, 224, 224)
process = psutil.Process(os.getpid())

def run_network():
    out = model(x)

for idx in range(100):
    print(idx, process.memory_full_info().rss / 1024**2)
    run_network()

result:

1 328.0234375
2 344.46484375
3 354.515625
4 344.19921875
5 344.16015625
6 343.96484375
7 344.02734375
8 344.08984375
9 344.17578125
10 343.98828125
11 344.05078125
12 344.1171875
13 344.1796875
14 343.984375
15 344.046875
16 344.1328125
17 344.19140625
18 344.25390625
19 344.05859375
20 344.12109375
21 344.18359375
22 344.24609375
23 344.05078125
24 344.11328125
25 344.17578125
26 344.23828125
27 344.04296875
28 344.10546875
29 344.16796875
30 344.23046875
31 344.03515625
32 344.12890625
33 344.19140625
34 344.25390625
35 344.05859375
36 344.1484375
37 344.22265625
38 344.2890625
39 344.1171875
40 344.1796875
41 344.2421875
42 344.3046875
43 344.109375
44 344.171875
45 344.234375
46 344.296875
47 344.359375
48 344.16015625
49 344.21875
50 344.28125
51 344.34375
52 344.1484375
53 344.2109375
54 344.2734375
55 344.3359375
56 344.140625
57 344.203125
58 344.265625
59 344.328125
60 344.1328125
61 344.1953125
62 344.2578125
63 349.5078125
64 344.1875
65 344.1875
66 344.25
67 349.5
68 344.1796875
69 344.1796875
70 344.2421875
71 344.3046875
72 344.109375
73 344.171875
74 344.234375
75 344.296875
76 344.359375
77 344.1640625
78 344.2265625
79 349.4765625
80 344.4140625
81 344.15625
82 344.21875
83 344.28125
84 344.34375
85 344.171875
86 344.234375
87 344.296875
88 344.359375
89 344.1640625
90 344.2265625
91 344.2890625
92 344.34765625
93 344.15234375
94 344.21484375
95 344.27734375
96 344.33984375
97 344.14453125
98 344.20703125
99 349.45703125

I also noticed that the following code:

def run_net(self):
    x = torch.randn(2, 3, 480, 480)
    logit = self.net(x)

allocates memory in the first iterations both for x and for the model run:

 Line #    Mem usage    Increment   Line Contents
 ================================================
     50    252.8 MiB    252.8 MiB       @profile
     51                                 def run_net(self):
     52    258.3 MiB      5.5 MiB           x = torch.randn(2, 3, 480, 480)
     53    668.6 MiB    410.3 MiB           logit = self.net(x)

Total memory is 252.8 MiB before the call, 5.5 MiB is allocated for x and 410.3 MiB for the inference,
while after n iterations the picture looks like this:

Line #    Mem usage    Increment   Line Contents
================================================
    50   1001.2 MiB   1001.2 MiB       @profile
    51                                 def run_net(self):
    52   1001.2 MiB      0.0 MiB           x = torch.randn(2, 3, 480, 480)
    53   1217.6 MiB    216.5 MiB           logit = self.net(x)

So no memory is allocated for x at all, and the inference allocates only about half of the initial amount, but total memory is now 1001 MiB.
I suppose this has the same nature, as if some kind of caching were happening.

Python uses function scoping, which frees all variables which are only used in the function scope.
Your memory footprint should therefore be lower, as e.g. out will be deleted and with it the intermediate tensors, which were created in the forward method. My code snippet doesn’t use it and thus the “first” out tensor with the computation graph is still in memory while the second iteration is running.

In your second example, the memory might just be reused. I’m not sure how the CPU memory allocation works in Python and PyTorch in particular. For GPU memory we use a custom caching allocator, which reuses memory if possible without reallocating.
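
To illustrate the point, something like this keeps at most one set of intermediates alive at a time (just a sketch; I’ve also added torch.no_grad(), since you don’t need the computation graph for inference):

import torch

model = PSPNet()
x = torch.randn(2, 3, 224, 224)

def run_network():
    # `out` and every intermediate created in forward are local to this function
    # and become collectable as soon as it returns
    with torch.no_grad():
        out = model(x)
    return out.shape  # return only what you actually need, not the full tensor

for idx in range(100):
    run_network()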

But this memory reuse is not acceptable behavior for me, because I also run a detectron2 model, memory usage reaches 5 GB, and my server crashes.
Can I avoid this behavior (the reuse) and use only a minimal amount of memory on every run, as in the first example in my previous reply?

I don’t know how the CPU allocation is handled, so I cannot be of much help here.

But do you have any idea about the difference between the next two samples?
First one:

x = torch.randn(1, 3, 333, 332)

process = psutil.Process(os.getpid())
model = PSPNet()

@profile
def run_network(idx):
    mem = process.memory_full_info().rss / 1024**2
    print(idx, mem)
    out = model(x)
    del out

def main():
    for idx in range(40):
        run_network(idx)

main()

Its output repeats the following pattern:

Line #    Mem usage    Increment   Line Contents
================================================
   132    354.2 MiB    354.2 MiB   @profile
   133                             def run_network(idx):
   134    354.2 MiB      0.0 MiB       mem = process.memory_full_info().rss / 1024**2
   135    354.2 MiB      0.0 MiB       print(idx, mem)
   136    433.0 MiB     78.8 MiB       out = model(x)
   137    366.8 MiB      0.0 MiB       del out

That is, the network output really gets deleted.

And here is the sample from my server:

    @profile
    def predict(self, img):

        with torch.no_grad():
            data = torch.randn(1, 3, 333, 332)

            logit = self.net(data)
            pred = torch.sigmoid(logit.cpu())[0][0].data.numpy()
            mask = pred >= 0.5

        return None

which, in turn, doesn’t lead to the memory being freed, but rather the opposite (up to a certain size, as we saw).
Its output:

 Line #    Mem usage    Increment   Line Contents
 ================================================
     55    560.0 MiB    560.0 MiB       @profile
     56                                 def predict(self, img):
     57
     58    560.0 MiB      0.0 MiB           with torch.no_grad():
     59                                         # data = self.test_image_transforms(img)
     60                                         # data = torch.unsqueeze(data, dim=0)
     61                                         # data = data.to(self.device)
     62    560.0 MiB      0.0 MiB               data = torch.randn(1, 3, 333, 332)
     63
     64    685.0 MiB    125.0 MiB               logit = self.net(data)
     65
     66    685.0 MiB      0.0 MiB               del logit
     67                                         # pred = torch.sigmoid(logit.cpu())[0][0].data.numpy()
     68                                         # mask = pred >= 0.5
     69
     70
     71    685.0 MiB      0.0 MiB           return None

Since the first example seems to work as intended, I would recommend trying to implement your second workflow using the same code logic. Currently, I’m not sure how the second code snippet is called, as it seems to be wrapped in a Python class.
If your memory usage is increasing in each iteration using the second code snippet, you are probably storing some tensors or references to them. Note that the gc doesn’t immediately free the deleted objects, so you could experiment with gc.set_threshold for the different levels.
It should however not run out of memory.
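E.g. a quick way to experiment with the thresholds (the values here are arbitrary example values, not a recommendation):

import gc

print(gc.get_threshold())   # defaults to (700, 10, 10)

# trigger generation-0 collections more frequently
gc.set_threshold(100, 5, 5)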

I found out the reason for the memory growth…
It happens when inputs have different sizes.

The following code uses detectron2, but the previous model behaves in the same way.

import detectron2
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import ColorMode
from glob import glob
setup_logger()

import os
import numpy as np
import cv2
import random
import uuid
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.data.catalog import Metadata
from memory_profiler import profile
import torch
import psutil


CLASSES = ['short_sleeved_shirt', 'long_sleeved_shirt', 'short_sleeved_outwear', 'long_sleeved_outwear', 'vest', 'sling', 'shorts', 'trousers', 'skirt', 'short_sleeved_dress', 'long_sleeved_dress', 'vest_dress', 'sling_dress']


cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml"))


metadata = Metadata()
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(CLASSES)

metadata.set(thing_classes=CLASSES)

cfg.MODEL.WEIGHTS = os.path.join('/home/algernone/DNNS/dpf2/model_final_40000.pth')
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.55  # set the testing threshold for this model
cfg.MODEL.DEVICE = 'cpu'
predictor = DefaultPredictor(cfg)

process = psutil.Process(os.getpid())

for i in range(10):
    for img_name in glob('content/*.jpg'):
        im = cv2.imread(img_name)
        outputs = predictor(im)
        print(im.shape, process.memory_full_info().rss / 1024**2)       

5 runs for one image:

(550, 367, 3) 1565.56640625
(550, 367, 3) 1684.4140625
(550, 367, 3) 1714.15625
(550, 367, 3) 1743.78125
(550, 367, 3) 1565.8203125

5 runs for another image:

(1800, 3200, 3) 1583.125
(1800, 3200, 3) 1687.15234375
(1800, 3200, 3) 1671.32421875
(1800, 3200, 3) 1600.40234375
(1800, 3200, 3) 1567.5

It is strange that the larger image took less memory…

and 5 runs for both images:
(550, 367, 3) 1580.203125
(1800, 3200, 3) 2706.109375
(550, 367, 3) 2690.5078125
(1800, 3200, 3) 2785.1953125
(550, 367, 3) 2798.828125
(1800, 3200, 3) 2862.90625
(550, 367, 3) 2876.65234375
(1800, 3200, 3) 2690.60546875
(550, 367, 3) 2690.66796875
(1800, 3200, 3) 2785.04296875

The memory usage increased significantly compared to running inference on these images separately.
But if I can’t stick to a single input size, this does not solve my problem.
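One workaround I could try, assuming my pipeline tolerates a fixed resolution, would be to bring every image to the same size before inference (a sketch reusing predictor and glob from the snippet above; the target size is arbitrary and not verified to stop the growth):

import cv2

TARGET_SIZE = (800, 800)  # (width, height), an arbitrary fixed size for illustration

for img_name in glob('content/*.jpg'):
    im = cv2.imread(img_name)
    im = cv2.resize(im, TARGET_SIZE, interpolation=cv2.INTER_LINEAR)
    outputs = predictor(im)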

@Chame_call
Would this thread help you? A somewhat similar problem was solved for us by setting the LRU_CACHE_CAPACITY=1 environment variable.
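For reference, it just needs to be present in the process environment before the script starts, e.g. (your_script.py is a placeholder for your entry point):

LRU_CACHE_CAPACITY=1 python3 your_script.py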


Could you explain the meaning of the variable?

Setting the LD_PRELOAD environment variable so that jemalloc is loaded instead of the default CPU allocator solved the problem.

My launch command is as follows:
LD_PRELOAD=./libjemalloc.so.1 python3 app.py.
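
If jemalloc is not installed yet, something like the following should work on Ubuntu/Debian (the package name and library path are assumptions for that distro and may differ on yours):

# install jemalloc (Ubuntu/Debian; newer releases ship libjemalloc2 / libjemalloc.so.2)
sudo apt-get install libjemalloc1
# find where the shared library landed
dpkg -L libjemalloc1 | grep libjemalloc.so
# preload it when starting the app
LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libjemalloc.so.1 python3 app.py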



@Chame_call

So you set this environment variable prior to running your script?

I am scratching my head over a memory leak where CPU memory consumption grows during inference. I’m not sure why the loaded image data keeps accumulating.