I thought the only way to leak memory in Python was to keep adding data to a structure indefinitely, and I'm fairly sure I'm not doing that.
I'm using pytorch-lightning and leaking memory on lines that I don't see how could possibly leak.
First I enabled tracemalloc and created a callback class that shows memory allocations between epochs.
import tracemalloc
tracemalloc.start()
last_snapshot = None

class LeakDetection(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        global last_snapshot
        new_snapshot = tracemalloc.take_snapshot()
        if last_snapshot is not None:
            # show the 20 largest allocation deltas since the previous epoch
            top_stats = new_snapshot.compare_to(last_snapshot, 'lineno')
            for stat in top_stats[:20]:
                print(stat)
        last_snapshot = new_snapshot
trainer = pl.Trainer(gpus=[0],
                     max_epochs=4,
                     limit_train_batches=1000,
                     limit_val_batches=10,
                     callbacks=[LeakDetection()])
The above reports leaks on every epoch at
/transformer.py:713: size=106 MiB (+35.3 MiB), count=1540440 (+513402), average=72 B
/transformer.py:704: size=422 KiB (+141 KiB), count=6000 (+2000), average=72 B
/transformer.py:703: size=422 KiB (+141 KiB), count=6000 (+2000), average=72 B
These lines point to new klugy code I added to track accuracy. While I'd appreciate advice on making this more torch-clean, what I'm really after is fixing my understanding: how can this even leak?
In my forward method, after the losses are calculated, I added:
def forward(self, ...
    .... calculate losses
    # calculate x,y then locations then nearest then accuracy ... my new code
    next_link_x_n = torch.atan2(next_link_n4[:, 0], next_link_n4[:, 1])  # line 703
    next_link_y_n = torch.atan2(next_link_n4[:, 2], next_link_n4[:, 3])  # line 704
    center_position_n4 = CenterPosition_NM4.reshape(N * M, 4)
    center_x_n = torch.atan2(center_position_n4[:, 0], center_position_n4[:, 1])
    center_y_n = torch.atan2(center_position_n4[:, 2], center_position_n4[:, 3])
    next_swc_n4 = NextSwcPosition_NM4.reshape(N * M, 4)
    correct_predictions = 0
    incorrect_predictions = 0
    for ii in range(N * M):
        if valid_next_link_n[ii]:
            distance_n = torch.square(center_x_n - next_link_x_n[ii]) + torch.square(center_y_n - next_link_y_n[ii])  # line 713
            distance_n[ii] = distance_n.max() + 1  # do not want self
            iii = torch.argmin(distance_n)
            if torch.equal(center_position_n4[iii], next_swc_n4[ii]):
                correct_predictions += 1
            else:
                incorrect_predictions += 1
To be torch-clean, at a minimum, I should have called detach() on next_link_n4 or wrapped this new accuracy code in "with torch.no_grad():", since accuracy is not part of the loss. Even so, how can these lines (703, 704, and 713) even leak?
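As a sanity check of what detach() and no_grad actually change, here is a small standalone sketch; the shapes and the next_link name are just stand-ins for my real tensors:

```python
import torch

next_link = torch.randn(5, 4, requires_grad=True)  # stand-in for next_link_n4

# Without protection, every op records a node in the autograd graph:
x = torch.atan2(next_link[:, 0], next_link[:, 1])
print(x.grad_fn)  # a *Backward node -- a graph is being built

# Either of these keeps the accuracy math out of the graph:
with torch.no_grad():
    x_ng = torch.atan2(next_link[:, 0], next_link[:, 1])
x_det = torch.atan2(next_link.detach()[:, 0], next_link.detach()[:, 1])
print(x_ng.grad_fn, x_det.grad_fn)  # None None
```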
I should note that this accuracy is logged.
The following appears to eliminate the leak, but what is the mechanism by which gradient information can cause a leak?
def forward(self, ...
    .... calculate losses
    with torch.no_grad():
        # calculate x,y then locations then nearest swcid then accuracy
        next_link_x_n = torch.atan2(next_link_n4[:, 0], next_link_n4[:, 1])
        next_link_y_n = torch.atan2(next_link_n4[:, 2], next_link_n4[:, 3])
        center_position_n4 = CenterPosition_NM4.reshape(N * M, 4)
        center_x_n = torch.atan2(center_position_n4[:, 0], center_position_n4[:, 1])
        center_y_n = torch.atan2(center_position_n4[:, 2], center_position_n4[:, 3])
        next_swc_n4 = NextSwcPosition_NM4.reshape(N * M, 4)
        correct_predictions = 0
        incorrect_predictions = 0
        for ii in range(N * M):
            if valid_next_link_n[ii]:
                distance_n = torch.square(center_x_n - next_link_x_n[ii]) + torch.square(center_y_n - next_link_y_n[ii])
                distance_n[ii] = distance_n.max() + 1  # do not want self
                iii = torch.argmin(distance_n)
                if torch.equal(center_position_n4[iii], next_swc_n4[ii]):
                    correct_predictions += 1
                else:
                    incorrect_predictions += 1
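My current suspicion, for anyone trying to reproduce this outside my model: every op on a requires_grad tensor appends nodes to the autograd graph, and keeping any downstream result alive (e.g. for logging) keeps the whole chain alive. A minimal sketch of that growth; the graph_size helper is my own, just walking grad_fn.next_functions:

```python
import torch

def graph_size(t):
    # count autograd nodes reachable from t.grad_fn (illustration helper)
    seen, stack = set(), [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return len(seen)

x = torch.randn(100, requires_grad=True)
total = x.sum()
for _ in range(50):
    total = total + x.mean()  # each iteration appends new graph nodes

print(graph_size(total))  # grows linearly with the loop count

with torch.no_grad():
    safe = total.detach() + x.mean()
print(graph_size(safe))   # 0 -- nothing was recorded
```

If `total` were stored (or logged as a live tensor) across epochs, all of those nodes and their saved buffers would be kept alive with it, which is consistent with the per-line growth tracemalloc reported.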