I read several posts online about memory leaks but could not find a solution. Thanks in advance for any help. I also filed an issue and am linking it here, so that other people facing the same problem have a better chance of finding the solution: [memory leak] [PyTorch] .backward(create_graph=True) · Issue #7343 · pytorch/pytorch · GitHub
Below is the distilled code that produces the memory leak with autograd create_graph=True:
import gc

import psutil
import torch

torch.manual_seed(2332)
dim = 10
winit = torch.randn([dim])
w1 = torch.randn([dim+1], requires_grad=True)
w2 = w1[dim:]
torch.autograd.set_detect_anomaly(True)
for e in range(1000):
    param0 = winit.clone().detach().requires_grad_()
    p = torch.sum((param0-w2)*(param0-w2))
    param0grad = torch.autograd.grad(p, param0, create_graph=True)[0]  # mem leak
    param1 = param0 - 0.1*param0grad
    del param1  # del to try to free the computational graph
    del param0grad
    del p
    del param0
    gc.collect()
    if e % 100 == 0:
        proc = psutil.Process()  # renamed from p to avoid shadowing the loss above
        mem_info = proc.memory_info()
        print('epoch ', e, ' mem info ', mem_info)
========
Runtime output (RSS grows steadily across epochs):
epoch 0 mem info pmem(rss=95322112, vms=4828065792, pfaults=24924, pageins=135)
epoch 100 mem info pmem(rss=95776768, vms=4828065792, pfaults=25040, pageins=135)
epoch 200 mem info pmem(rss=96342016, vms=4830425088, pfaults=25178, pageins=135)
epoch 300 mem info pmem(rss=96907264, vms=4830687232, pfaults=25316, pageins=135)
epoch 400 mem info pmem(rss=97456128, vms=4831997952, pfaults=25450, pageins=135)
epoch 500 mem info pmem(rss=97988608, vms=4832260096, pfaults=25580, pageins=135)
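For comparison, here is the same loop with create_graph left at its default of False. As far as I understand, the returned gradient is then not attached to any graph (its grad_fn is None and it does not require grad), so there should be no graph left over to leak; the leak seems specific to create_graph=True. This is a sketch of that comparison, not a fix:

```python
import gc

import torch

torch.manual_seed(2332)
dim = 10
winit = torch.randn([dim])
w1 = torch.randn([dim + 1], requires_grad=True)
w2 = w1[dim:]

for e in range(200):
    param0 = winit.clone().detach().requires_grad_()
    p = torch.sum((param0 - w2) * (param0 - w2))
    # default create_graph=False: the gradient is detached from any graph
    param0grad = torch.autograd.grad(p, param0)[0]
    assert param0grad.grad_fn is None
    assert not param0grad.requires_grad
    param1 = param0 - 0.1 * param0grad
    del param1, param0grad, p, param0
    gc.collect()
```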
Thank you.