I was trying to test torch.multiprocessing
and I noticed a pretty strange memory profile. Here's a minimal reproducible example:
import torch
import torch.distributed as dist
import time
import os
from torch.multiprocessing import Pipe, Process

_WORLD_SIZE = 2
_RANGE = 1000000


def ping_it(rank, pipe, device, XX):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '7044'
    # dist.init_process_group("gloo", rank=rank, world_size=_WORLD_SIZE)
    for i in range(_RANGE):
        print(rank, i)
        if i > 0:
            # receive the peer's tensor and move it onto this worker's GPU
            X = pipe.recv()
            X = X.to('cuda:{}'.format(device))
        else:
            # first iteration: move our own tensor onto this worker's GPU
            XX = XX.to('cuda:{}'.format(device))
        # ping the (now CUDA) tensor over to the other worker every iteration
        pipe.send(XX)


if __name__ == "__main__":
    c1, c2 = Pipe()
    X1 = torch.randn(10000, 10000)
    X2 = torch.randn(10000, 10000)
    p1 = Process(target=ping_it, args=(0, c1, 1, X1))
    p2 = Process(target=ping_it, args=(1, c2, 2, X2))
    p1.start()
    p2.start()
    time.sleep(30)
    p1.join()
    p2.join()
I also kicked off a memory logger alongside it, like so:
import os
import time
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--file", type=str, default="check.log")
parser.add_argument("--interval", type=int, default=60)
args = parser.parse_args()


def get_memory():
    # third field of the "Mem:" line of `free -m` is used system RAM in MB
    tim = time.time()
    ram = os.popen('free -m').readlines()[1].split()[2]
    return tim, ram


if __name__ == "__main__":
    with open(args.file, 'a') as f:
        f.write("Time, RAM\n")
        while True:
            tim, ram = get_memory()
            print(tim, ram)
            f.write("{}, {}\n".format(tim, ram))
            f.flush()  # flush each sample, since the loop never exits cleanly
            time.sleep(args.interval)
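To put a number on the growth, the log can be parsed with something like the following (just a quick sketch assuming the "time, used-MB" format written above, not part of the repro itself):

import csv

times, ram_mb = [], []
with open("check.log") as f:
    next(f)  # skip the "Time, RAM" header line
    for row in csv.reader(f):
        if len(row) == 2:
            times.append(float(row[0]))
            ram_mb.append(float(row[1]))

if len(times) > 1:
    # overall growth between the first and last sample, in MB/s
    rate = (ram_mb[-1] - ram_mb[0]) / (times[-1] - times[0])
    print("approx. growth: {:.3f} MB/s".format(rate))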
Here’s the memory profile I got:
I'm leaking anywhere from 0.2-0.4 MB/s across different runs. Any idea how I could fix this, or whether there's even anything to fix?