I used share_memory() to share a model with a subprocess. In the child process, I run a forward pass on a random CPU tensor, and the memory usage increases from 57M to 87M on the first forward pass. After that, the memory usage still increases gradually. What is the reason for that? The code is below.
Environment:
- Ubuntu 16.04
- Python 3.6.5
- PyTorch 1.0.1.post2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import time
import argparse
import logging
import os
import psutil
import torch
import torch.nn as nn
import torch.multiprocessing as mp
class Model(nn.Module):
    """Small convolutional network used to reproduce the memory question.

    The network is four ``Conv2d -> BatchNorm2d -> LeakyReLU`` stages.
    Every convolution uses a 3x3 kernel, stride 1 and padding 1, so the
    spatial size of the input is preserved; only the channel count changes
    (3 -> 256 -> 512 -> 512 -> 512).
    """

    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.Sequential(
            # Stage 1: 3 -> 256 channels.
            nn.Conv2d(3, 256, 3, 1, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            # Stage 2: 256 -> 512 channels.
            nn.Conv2d(256, 512, 3, 1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            # Stage 3: 512 -> 512 channels.
            nn.Conv2d(512, 512, 3, 1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            # Stage 4: 512 -> 512 channels.
            nn.Conv2d(512, 512, 3, 1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        """Apply the convolutional stack to ``x`` and return the result.

        ``x`` must have 3 channels (the first convolution takes 3 input
        channels); the output has 512 channels.
        """
        return self.conv(x)
class SimpleAgent(object):
    """Wraps a model and runs single forward passes while reporting RSS."""

    def __init__(self, model):
        """Store ``model`` after switching it to evaluation mode.

        ``model.eval()`` returns the module itself, so ``self.model`` is the
        same object that was passed in.
        """
        self.model = model.eval()

    def test(self):
        """Run one forward pass on a random CPU tensor and return the output.

        The resident set size (RSS, in bytes) of the current process is
        printed immediately before and immediately after the forward pass.
        """
        x = torch.randn(1, 3, 16, 8, device=torch.device("cpu"))
        # One handle for this process is enough; memory_info() is re-read
        # on every call, so both prints reflect the RSS at that moment.
        proc = psutil.Process(os.getpid())
        print("Before forward: {}".format(proc.memory_info().rss))
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = self.model(x)
        print("After forward: {}".format(proc.memory_info().rss))
        return outputs
def run_agent(model, interval=5):
    """Worker entry point: run forward passes on ``model`` forever.

    Args:
        model: The module to evaluate; it is wrapped in a ``SimpleAgent``.
        interval: Seconds to sleep between two forward passes. Defaults to
            5, the value that was previously hard-coded.

    This function never returns; the process running it has to be
    terminated from the outside.
    """
    agent = SimpleAgent(model)
    while True:
        agent.test()
        time.sleep(interval)
if __name__ == "__main__":
    # Script entry point: build the model in the parent, move its tensors to
    # shared memory, then hand it to worker process(es) that run forward
    # passes in a loop and print their own RSS.
    logging.basicConfig(
        handlers=[logging.StreamHandler()],
        level=logging.INFO,
        format="[%(asctime)s] %(filename)s: [%(levelname)s] %(message)s",
    )
    device = torch.device("cpu")

    # RSS of the parent process before and after the model is constructed.
    print("Before creat model: {}".format(psutil.Process(os.getpid()).memory_info().rss))
    model = Model().to(device)
    print("After creat model: {}".format(psutil.Process(os.getpid()).memory_info().rss))

    # Must happen before the workers are started so that they receive the
    # shared-memory tensors rather than private copies.
    model.share_memory()

    process_list = [
        mp.Process(
            target=run_agent,
            args=(model,),
        )
        for _ in range(1)
    ]
    print("len list: {}".format(len(process_list)))

    for i, p in enumerate(process_list):
        print("start {}".format(i))
        p.start()
    print("Init Successful")

    # run_agent loops forever, so join() blocks until the workers are
    # terminated externally.
    for p in process_list:
        p.join()