I am implementing a custom function with autograd, and I found a memory leak while running it. I have simplified and isolated the problematic part below.
import torch
import gc

def get_tensors_in_memory():
    tensors = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                tensors.append(obj)
        except:
            pass
    return tensors

def func(y, dy):
    g = G(y)
    dywg = torch.autograd.grad(g, (y,) + tuple(G.parameters()), grad_outputs=dy)
    dx = dywg[0]
    for i, param in enumerate(G.parameters()):
        param.grad = dywg[1 + i]
    x = y + g
    return x, dx

torch.manual_seed(0)
G = torch.nn.Linear(3, 3)
y = torch.rand(3, requires_grad=True)
dy = torch.rand(3)

tensors_before = get_tensors_in_memory()
x, dx = func(y, dy)
del x, dx, G, y, dy
gc.collect()
tensors_after = get_tensors_in_memory()

leftover = list(set(tensors_after) - set(tensors_before))
print("There are %d leftover tensors" % len(leftover))
for t in leftover:
    print(t)
As you can see, I tried to delete every object I know of, but the output shows a leftover tensor:
There are 1 leftover tensors
tensor([[0.0116, 0.0876, 0.1524],
[0.0156, 0.1178, 0.2050],
[0.0179, 0.1351, 0.2351]])
Furthermore, I found that the leftover tensor equals G.weight.grad. I don't know why it remains in memory. Even stranger, G.bias.grad does not remain.
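Not an explanation, but one way to dig further (a hypothetical debugging sketch, not from the original post): gc.get_referrers() lists the container objects that still point at a given tensor, which can help reveal what is keeping G.weight.grad alive.

```python
import gc
import torch

# Reproduce the setup from the question.
G = torch.nn.Linear(3, 3)
y = torch.rand(3, requires_grad=True)
dy = torch.rand(3)

g = G(y)
grads = torch.autograd.grad(g, (y,) + tuple(G.parameters()), grad_outputs=dy)
for p, gr in zip(G.parameters(), grads[1:]):
    p.grad = gr

# Ask the garbage collector which tracked objects still reference the
# weight grad; at minimum the `grads` tuple above still holds it.
referrers = gc.get_referrers(G.weight.grad)
for r in referrers:
    print(type(r))
```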
@yangzh This looks deeply entangled with how the autograd graph is implemented.
TRY 1
I was able to eliminate the leftover tensor with a small change to the code: if I create the nn.Linear inside func, there are no leftover tensors.
import torch
import gc

def get_tensors_in_memory():
    tensors = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                tensors.append(obj)
        except:
            pass
    return tensors

def func(y, dy):
    G = torch.nn.Linear(3, 3)
    g = G(y)
    dywg = torch.autograd.grad(g, [y] + list(G.parameters()), grad_outputs=dy)
    dx = dywg[0]
    for i, param in enumerate(G.parameters()):
        param.grad = dywg[1 + i]
    x = y + g
    return x.detach(), dx.detach()

torch.manual_seed(0)
y = torch.rand(3, requires_grad=True)
dy = torch.rand(3)

tensors_before = get_tensors_in_memory()
x, dx = func(y, dy)
del x, dx, y, dy
gc.collect()
tensors_after = get_tensors_in_memory()

leftover = list(set(tensors_after) - set(tensors_before))
print("There are %d leftover tensors" % len(leftover))
for t in leftover:
    print(t)
TRY 2
The second thing I found is that the tensor in question is present in neither globals() nor locals(), yet it does show up in gc.get_objects(). gc.get_objects() gets its data from the interpreter's thread state, which is implemented at the C level.
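This observation can be reproduced without torch at all (a minimal, self-contained sketch): an object with no surviving name binding is invisible to globals() and locals(), but gc.get_objects() walks the interpreter's tracked container objects and still finds it.

```python
import gc

holder = []

def make_list():
    inner = [1, 2, 3]
    holder.append(inner)  # keep it alive without a surviving name
    return id(inner)

obj_id = make_list()

# No global name is bound to the inner list...
in_globals = any(id(v) == obj_id for v in globals().values())
# ...but the collector, which tracks every container object, still sees it.
in_gc = any(id(o) == obj_id for o in gc.get_objects())
print(in_globals, in_gc)  # False True
```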
Thank you for your answer. I cannot use TRY 1 because I need the network outside the function, and TRY 2 is an observation rather than a solution. The leak does exist, and I cannot simply ignore it. In my real code, the program runs fine for the first few epochs, but memory consumption keeps increasing (detectable in several ways), and after tens of epochs the program crashes with an out-of-memory error.
Now I have found a workaround: replacing torch.autograd.grad with g.backward(dy), which avoids keeping extra references to these objects.
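For reference, a minimal sketch of that workaround (my reconstruction, not the exact code from the post): g.backward(dy) populates y.grad and every param.grad in a single pass, so no tuple of returned gradients needs to be kept around afterwards.

```python
import torch

torch.manual_seed(0)
G = torch.nn.Linear(3, 3)

def func(y, dy):
    g = G(y)
    g.backward(dy)  # fills y.grad and each param.grad directly
    dx = y.grad     # gradient w.r.t. the input, as before
    x = y + g
    return x, dx

y = torch.rand(3, requires_grad=True)
dy = torch.rand(3)
x, dx = func(y, dy)
```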