How can using tensors with gradients, when those gradients are never used for the loss, lead to a memory leak?

I was under the impression that the only way to leak memory in Python was to keep adding data to a structure without ever removing it.
I'm pretty sure I'm not doing that.
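
The classic pattern I had in mind is something like this (invented names, just to illustrate what I mean by "adding endlessly more data to a structure"):

history = []  # grows forever because nothing is ever removed

def record_metric(value):
    history.append(value)  # every call makes the list, and memory use, larger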

I am using pytorch-lightning, and memory is leaking on a line where I don't see how a leak is even possible.

First I enabled tracemalloc and created a callback class that reports new memory allocations between epochs.

import tracemalloc

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

tracemalloc.start()
last_snapshot = None

class LeakDetection(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        global last_snapshot

        # compare this epoch's allocations against the previous epoch's snapshot
        new_snapshot = tracemalloc.take_snapshot()
        if last_snapshot is not None:
            top_stats = new_snapshot.compare_to(last_snapshot, 'lineno')

            for stat in top_stats[:20]:
                print(stat)

        last_snapshot = new_snapshot

trainer = pl.Trainer(gpus=[0],
        max_epochs=4,
        limit_train_batches=1000,
        limit_val_batches=10,
        callbacks=[LeakDetection()])

Every epoch, this reports growing allocations at:

/transformer.py:713: size=106 MiB (+35.3 MiB), count=1540440 (+513402), average=72 B                     
/transformer.py:704: size=422 KiB (+141 KiB), count=6000 (+2000), average=72 B
/transformer.py:703: size=422 KiB (+141 KiB), count=6000 (+2000), average=72 B

These lines point to kludgy new code I added to track accuracy. While I would appreciate advice on making it more torch-clean, what I really want to fix is my understanding: how can this code leak at all?

In my forward method, after the losses are calculated, I added:

def forward(self, ...):
    # ... calculate losses ...
    # calculate x,y, then locations, then nearest, then accuracy ... my new code
    next_link_x_n = torch.atan2(next_link_n4[:,0], next_link_n4[:,1])  # line 703
    next_link_y_n = torch.atan2(next_link_n4[:,2], next_link_n4[:,3])  # line 704
    center_position_n4 = CenterPosition_NM4.reshape(N*M, 4)
    center_x_n = torch.atan2(center_position_n4[:,0], center_position_n4[:,1])
    center_y_n = torch.atan2(center_position_n4[:,2], center_position_n4[:,3])

    correct_predictions = 0
    incorrect_predictions = 0
    for ii in range(N*M):
        if valid_next_link_n[ii]:
            distance_n = torch.square(center_x_n - next_link_x_n[ii]) + torch.square(center_y_n - next_link_y_n[ii])  # line 713
            distance_n[ii] = distance_n.max() + 1  # do not want self
            iii = torch.argmin(distance_n)
            if (center_position_n4[iii,0] == NextSwcPosition_NM4.reshape(N*M,4)[ii,0]) and \
               (center_position_n4[iii,1] == NextSwcPosition_NM4.reshape(N*M,4)[ii,1]) and \
               (center_position_n4[iii,2] == NextSwcPosition_NM4.reshape(N*M,4)[ii,2]) and \
               (center_position_n4[iii,3] == NextSwcPosition_NM4.reshape(N*M,4)[ii,3]):
                correct_predictions += 1
            else:
                incorrect_predictions += 1

To be torch-clean, at a minimum I should have called detach() on next_link_n4 or wrapped this new accuracy code in "with torch.no_grad():", since the accuracy is not part of the loss. Even so, how can these lines (703, 704, and 713) leak at all?
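
For concreteness, the detach() variant I mean would look roughly like this (a sketch I have not run; I am assuming detaching next_link_n4 and CenterPosition_NM4 is enough, since the other tensors are derived from them):

# sketch only: cut the metric inputs out of the autograd graph up front
next_link_d_n4 = next_link_n4.detach()
center_position_n4 = CenterPosition_NM4.detach().reshape(N*M, 4)
next_link_x_n = torch.atan2(next_link_d_n4[:,0], next_link_d_n4[:,1])
next_link_y_n = torch.atan2(next_link_d_n4[:,2], next_link_d_n4[:,3])
center_x_n = torch.atan2(center_position_n4[:,0], center_position_n4[:,1])
center_y_n = torch.atan2(center_position_n4[:,2], center_position_n4[:,3])
# ... the accuracy loop stays the same as above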

I should note that this accuracy is logged.
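
The logging is roughly like this (names are placeholders, not my exact code; correct_predictions and incorrect_predictions are plain Python ints by that point, not tensors):

# hypothetical sketch of the logging call
accuracy = correct_predictions / max(correct_predictions + incorrect_predictions, 1)
self.log("train_accuracy", accuracy, on_step=False, on_epoch=True)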

The following appears to eliminate the leak, but what is the mechanism by which gradient information can result in a leak?

def forward(self, ...):
    # ... calculate losses ...
    with torch.no_grad():
        # calculate x,y, then locations, then nearest swcid, then accuracy
        next_link_x_n = torch.atan2(next_link_n4[:,0], next_link_n4[:,1])
        next_link_y_n = torch.atan2(next_link_n4[:,2], next_link_n4[:,3])
        center_position_n4 = CenterPosition_NM4.reshape(N*M, 4)
        center_x_n = torch.atan2(center_position_n4[:,0], center_position_n4[:,1])
        center_y_n = torch.atan2(center_position_n4[:,2], center_position_n4[:,3])

        correct_predictions = 0
        incorrect_predictions = 0
        for ii in range(N*M):
            if valid_next_link_n[ii]:
                distance_n = torch.square(center_x_n - next_link_x_n[ii]) + torch.square(center_y_n - next_link_y_n[ii])
                distance_n[ii] = distance_n.max() + 1  # do not want self
                iii = torch.argmin(distance_n)
                if (center_position_n4[iii,0] == NextSwcPosition_NM4.reshape(N*M,4)[ii,0]) and \
                   (center_position_n4[iii,1] == NextSwcPosition_NM4.reshape(N*M,4)[ii,1]) and \
                   (center_position_n4[iii,2] == NextSwcPosition_NM4.reshape(N*M,4)[ii,2]) and \
                   (center_position_n4[iii,3] == NextSwcPosition_NM4.reshape(N*M,4)[ii,3]):
                    correct_predictions += 1
                else:
                    incorrect_predictions += 1
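
In case it helps, here is a stripped-down sketch of the overall pattern (the module, shapes, and names are invented for illustration; it mirrors the structure of my forward, and I have not verified that this minimal version leaks on its own):

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, x_n4, centers_n4):
        out_n4 = self.lin(x_n4)          # tracks gradients
        loss = out_n4.pow(2).mean()      # the only tensor the loss needs

        # metric computed from grad-tracking tensors, outside no_grad()
        out_x_n = torch.atan2(out_n4[:,0], out_n4[:,1])
        center_x_n = torch.atan2(centers_n4[:,0], centers_n4[:,1])
        correct = 0
        for ii in range(x_n4.shape[0]):
            distance_n = torch.square(center_x_n - out_x_n[ii])  # builds a new graph every iteration
            correct += int(torch.argmin(distance_n) == ii)
        return loss, correct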