Memory leaks at inference

I’m trying to serve my model with Flask, but I ran into high memory consumption that eventually brings the server down.

I profiled the app to find where the large allocations happen and traced them to model inference (if I comment out the network inference, the memory problems disappear).
First inference:

Line #    Mem usage    Increment   Line Contents
================================================
    49    261.6 MiB    261.6 MiB       @profile
    50                                 def predict(self, img):
    51    261.6 MiB      0.0 MiB           with torch.no_grad():
    52    269.3 MiB      7.7 MiB               data = self.test_image_transforms(img)
    53    269.3 MiB      0.0 MiB               data = torch.unsqueeze(data, dim=0)
    54    269.3 MiB      0.0 MiB               data = data.to(self.device)
    55    442.1 MiB    172.8 MiB               logit = self.net(data)
    56    442.1 MiB      0.0 MiB               pred = torch.sigmoid(logit.cpu())[0][0].data.numpy()
    57    442.1 MiB      0.0 MiB               mask = pred >= 0.5
    58
    59
    60    442.1 MiB      0.0 MiB           return mask

It can be seen that there’s a huge allocation on line 55 (172.8 MiB), while total memory usage at function entry is 261.6 MiB.

Second inference (after 10 sec):

Line #    Mem usage    Increment   Line Contents
================================================
    49    374.4 MiB    374.4 MiB       @profile
    50                                 def predict(self, img):
    51    374.4 MiB      0.0 MiB           with torch.no_grad():
    52    380.6 MiB      6.2 MiB               data = self.test_image_transforms(img)
    53    380.6 MiB      0.0 MiB               data = torch.unsqueeze(data, dim=0)
    54    380.6 MiB      0.0 MiB               data = data.to(self.device)
    55    548.5 MiB    167.9 MiB               logit = self.net(data)
    56    548.5 MiB      0.0 MiB               pred = torch.sigmoid(logit.cpu())[0][0].data.numpy()
    57    548.5 MiB      0.0 MiB               mask = pred >= 0.5
    58
    59
    60    548.5 MiB      0.0 MiB           return mask

Now 374.4 MiB are allocated at entry, and usage keeps growing like this on every subsequent inference.
Then I tried to deallocate the memory manually, deleting the output (del logit) and calling the garbage collector, but it didn’t help at all.
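Concretely, the cleanup I tried after each request looked roughly like this (a sketch; the random tensor is a stand-in for the real self.net(data) output):

```python
import gc
import torch

with torch.no_grad():
    # stand-in for logit = self.net(data); same 4-D shape as my output
    logit = torch.randn(1, 1, 64, 64)
    pred = torch.sigmoid(logit.cpu())[0][0].numpy()
    mask = pred >= 0.5

del logit, pred   # drop references to the intermediate tensors
gc.collect()      # force a collection pass -- this did not lower usage
# (I run on CPU; on CUDA one would also call torch.cuda.empty_cache())
```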

Then I went down into the forward method, where all the magic happens.

That’s a snapshot of the profiler at first inference:

 Line #    Mem usage    Increment   Line Contents
 ================================================
    116    269.3 MiB    269.3 MiB       @profile
    117                                 def forward(self, x):
    118    269.3 MiB      0.0 MiB           with torch.no_grad():
    119    269.3 MiB      0.0 MiB               h, w = x.size(2), x.size(3)
    120    317.6 MiB     48.3 MiB               f = self.base_network(x)
    121    317.6 MiB      0.0 MiB               p = self.psp(f)
    122    317.6 MiB      0.0 MiB               drop_1_out = self.drop_1(p)
    123    351.4 MiB     33.8 MiB               p = self.up_1(drop_1_out)
    124    351.4 MiB      0.0 MiB               p = self.drop_2(p)
    125
    126    364.3 MiB     12.9 MiB               p = self.up_2(p)
    127    364.3 MiB      0.0 MiB               p = self.drop_2(p)
    128
    129    396.3 MiB     32.0 MiB               p = self.up_3(p)
    130
    131    396.3 MiB      0.0 MiB               if (p.size(2) != h) or (p.size(3) != w):
    132    441.6 MiB     45.4 MiB                   p = F.interpolate(p, size=(h, w), mode='bilinear')
    133
    134    441.6 MiB      0.0 MiB               p = self.drop_2(p)
    135    487.0 MiB     45.4 MiB               r = self.final(p)
    136    487.0 MiB      0.0 MiB           return r

Here I also tried deleting the outputs of individual layers, but that didn’t help either, except for deleting the last tensor p:

 Line #    Mem usage    Increment   Line Contents
 ================================================
    116    265.2 MiB    265.2 MiB       @profile
    117                                 def forward(self, x):
    118    265.2 MiB      0.0 MiB           with torch.no_grad():
    119    265.2 MiB      0.0 MiB               h, w = x.size(2), x.size(3)
    120    301.5 MiB     36.3 MiB               f = self.base_network(x)
    121    301.9 MiB      0.4 MiB               p = self.psp(f)
    122    301.9 MiB      0.0 MiB               drop_1_out = self.drop_1(p)
    123    320.8 MiB     19.0 MiB               p = self.up_1(drop_1_out)
    124    320.8 MiB      0.0 MiB               p = self.drop_2(p)
    125
    126    328.8 MiB      8.0 MiB               p = self.up_2(p)
    127    328.8 MiB      0.0 MiB               p = self.drop_2(p)
    128
    129    347.6 MiB     18.8 MiB               p = self.up_3(p)
    130
    131    347.6 MiB      0.0 MiB               if (p.size(2) != h) or (p.size(3) != w):
    132    373.7 MiB     26.0 MiB                   p = F.interpolate(p, size=(h, w), mode='bilinear')
    133
    134    373.7 MiB      0.0 MiB               p = self.drop_2(p)
    135    400.2 MiB     26.6 MiB               r = self.final(p)
    136
    137    374.6 MiB      0.0 MiB               del p
    138
    139
    140    374.6 MiB      0.0 MiB           return r

As can be seen, deleting the tensor p released the ~26 MiB allocated on the previous line.

But if I try to delete the other tensors as well, something strange happens:

Line #    Mem usage    Increment   Line Contents
================================================
   116    264.9 MiB    264.9 MiB       @profile
   117                                 def forward(self, x):
   118    264.9 MiB      0.0 MiB           with torch.no_grad():
   119    264.9 MiB      0.0 MiB               h, w = x.size(2), x.size(3)
   120    305.0 MiB     40.2 MiB               f = self.base_network(x)
   121    305.0 MiB      0.0 MiB               p = self.psp(f)
   122    305.0 MiB      0.0 MiB               drop_1_out = self.drop_1(p)
   123    323.2 MiB     18.2 MiB               p = self.up_1(drop_1_out)
   124    323.2 MiB      0.0 MiB               p = self.drop_2(p)
   125
   126    331.7 MiB      8.5 MiB               p = self.up_2(p)
   127    331.7 MiB      0.0 MiB               p = self.drop_2(p)
   128
   129    352.3 MiB     20.6 MiB               up_3_out = self.up_3(p)
   130
   131    352.3 MiB      0.0 MiB               if (up_3_out.size(2) != h) or (up_3_out.size(3) != w):
   132    382.3 MiB     29.9 MiB                   up_3_out = F.interpolate(up_3_out, size=(h, w), mode='bilinear')
   133
   134    382.3 MiB      0.0 MiB               drop_2_out = self.drop_2(up_3_out)
   135    412.2 MiB     29.9 MiB               r = self.final(drop_2_out)
   136
   137    412.2 MiB      0.0 MiB               del p
   138    412.2 MiB      0.0 MiB               del up_3_out
   139    382.9 MiB      0.0 MiB               del drop_2_out
   140
   141
   142    382.9 MiB      0.0 MiB           return r

As can be seen, only deleting the last tensor (drop_2_out) actually releases memory.
Does anybody have ideas on how to free the allocated memory?