Memory waste when using the GPU

Hi, I’m working on a project in which I have to compute the k nearest neighbors of every input point at each iteration.
Admittedly, a naive pairwise-distance implementation is memory-hungry, but I want to keep it simple.
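For a sense of scale, here is a rough estimate of the largest tensors the naive approach creates. It assumes float32 points and the int64 indices that `torch.sort` returns. The helper name `knn_naive_bytes` is made up for illustration.

```python
def knn_naive_bytes(B, N, M, d, float_bytes=4, index_bytes=8):
    """Rough size in bytes of the largest tensors in a naive pairwise k-NN.

    Assumes float32 coordinates/distances and int64 sort indices.
    """
    expanded = B * N * M * d * float_bytes   # one expanded [B, N*M, d] copy
    distances = B * N * M * float_bytes      # the [B, M, N] distance matrix
    sort_indices = B * N * M * index_bytes   # idx returned by torch.sort
    return {"expanded_each": expanded,
            "distances": distances,
            "sort_indices": sort_indices}

sizes = knn_naive_bytes(32, 1000, 1000, 3)
for name, nbytes in sizes.items():
    print(f"{name}: {nbytes / 2**20:.0f} MiB")
```

For B = 32, N = M = 1000, d = 3, this prints roughly 366 MiB per expanded copy, 122 MiB for the distances and 244 MiB for the sort indices. That is before counting the difference and squared-difference temporaries, each of which is another expanded-size tensor.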

Here’s the code I wrote:

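```python
import torch

def Knn(ref_point, query_point, k):
    # ref_point   : [B, N, d] reference points
    # query_point : [B, M, d] query points
    # returns     : [B, M, k] indices of the k nearest reference points per query
    B, N, d = ref_point.shape
    _, M, _ = query_point.shape

    # Materialize every (query, reference) pair, query-major: [B, M*N, d]
    rep_query = query_point.unsqueeze(-2).expand(B, M, N, d).contiguous().view(B, M * N, d)
    repelem_ref_point = ref_point.unsqueeze(-3).expand(B, M, N, d).contiguous().view(B, M * N, d)

    # Euclidean distance of each pair, reshaped so that pw_dist[b, m, n] = |query_m - ref_n|
    pw_dist = torch.sqrt(torch.sum((rep_query - repelem_ref_point) ** 2, dim=-1)).view(B, M, N)
    value, idx = torch.sort(pw_dist, dim=-1)  # ascending distance along N
    # Skip position 0 (the point itself when the query points are also reference points).
    # This slice is a view that shares storage with the full [B, M, N] idx tensor.
    knn_idx = idx[:, :, 1:k + 1]  # [B, M, k]

    del rep_query
    del repelem_ref_point
    del pw_dist

    return knn_idx
```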
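Here is a sketch of the same computation using `torch.cdist` and `torch.topk`. `torch.cdist` returns the [B, M, N] distance matrix directly, without the two expanded [B, N*M, d] copies. `torch.topk` picks the k + 1 smallest distances without a full sort. Like the original `[1:k+1]` slice, it assumes each query point is also a reference point, so the first (zero-distance) neighbor is dropped. `knn_cdist` and `exclude_self` are hypothetical names. By default, `cdist` switches to a matrix-multiplication formulation when either input has more than 25 points, which can introduce small rounding differences in the distances.

```python
import torch

def knn_cdist(ref_point, query_point, k, exclude_self=True):
    """Indices [B, M, k] of the k nearest reference points for each query point.

    ref_point:   [B, N, d]
    query_point: [B, M, d]
    """
    # [B, M, N] Euclidean distances, without materializing [B, N*M, d] copies
    dist = torch.cdist(query_point, ref_point)
    # Ask for one extra neighbor and drop the first (the point itself), like [1:k+1]
    extra = 1 if exclude_self else 0
    _, idx = torch.topk(dist, k + extra, dim=-1, largest=False)  # ascending order
    # clone() gives the result its own small storage instead of viewing idx
    return idx[..., extra:].clone()
```

Usage would mirror the original: `knn_idx = knn_cdist(ref, query, k=2)`.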

The problem is that after this function returns, the memory of the intermediate variables used to compute the k-NN, such as rep_query and repelem_ref_point, is still occupied on the GPU, even though only the k-NN indices are returned.
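When checking whether those intermediates are really still alive, it may help to separate two numbers PyTorch reports. `torch.cuda.memory_allocated()` counts memory occupied by live tensors. `torch.cuda.memory_reserved()` also includes blocks that the caching allocator keeps for reuse after tensors are freed; `nvidia-smi` still shows those blocks as used by the process. The following is a minimal sketch; `cuda_memory_mib` is a hypothetical helper name.

```python
import torch

MiB = 2 ** 20

def cuda_memory_mib(device=None):
    """Return (allocated, reserved) CUDA memory in MiB, or (0.0, 0.0) without a GPU.

    allocated: memory occupied by live tensors.
    reserved:  allocated memory plus cached blocks held by PyTorch's caching allocator.
    """
    if not torch.cuda.is_available():
        return 0.0, 0.0
    return (torch.cuda.memory_allocated(device) / MiB,
            torch.cuda.memory_reserved(device) / MiB)

if torch.cuda.is_available():
    x = torch.ones(32, 1000, 1000, device="cuda")  # about 122 MiB of float32
    print("after alloc:      ", cuda_memory_mib())
    del x
    print("after del:        ", cuda_memory_mib())  # allocated drops; reserved usually does not
    torch.cuda.empty_cache()
    print("after empty_cache:", cuda_memory_mib())  # unused cached blocks released to the driver
```

If allocated memory drops after the intermediates are deleted but reserved memory does not, the tensors themselves were freed and the caching allocator is holding the space for reuse.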

In C++ and many other languages, variables declared inside a function are destroyed once the function returns, but PyTorch doesn’t seem to do so.
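Python (CPython) also frees a function's local objects on return. Reference counting releases an object as soon as its last reference disappears. A local survives only if something still refers to it, for example a returned object that holds a reference to it. The plain-Python sketch below illustrates this. `Buffer` and `View` are hypothetical stand-ins for a large intermediate tensor and a view of it.

```python
import weakref

class Buffer:
    """Hypothetical stand-in for a large intermediate tensor."""

class View:
    """Hypothetical stand-in for a tensor view: it keeps a reference to its base."""
    def __init__(self, base):
        self.base = base

def compute(return_view):
    tmp = Buffer()              # local intermediate
    probe = weakref.ref(tmp)    # observes tmp without keeping it alive
    result = View(tmp) if return_view else Buffer()
    return result, probe

# Only an unrelated object escapes, so tmp is freed as soon as compute() returns.
result, probe = compute(return_view=False)
print(probe() is None)  # True (CPython)

# The returned object references tmp, so tmp stays alive...
result, probe = compute(return_view=True)
print(probe() is None)  # False
del result              # ...until the returned object is deleted.
print(probe() is None)  # True
```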
Here’s a simple experiment to verify that:

```python
import torch

idx = Knn(torch.ones((32, 1000, 3), requires_grad=False).cuda(),
          torch.ones((32, 1000, 3), requires_grad=False).cuda(), k=2)
del idx
```

The printed output is:

Even though I deleted the unnecessary variables and emptied the cache, the memory was deallocated only after I deleted the function’s output, idx.
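This observation has a possible explanation, though it is an assumption not verified in this post. `idx[:, :, 1:k+1]` is a view, so the returned `knn_idx` shares storage with the full [B, M, N] int64 tensor produced by `torch.sort`. For the experiment's sizes that tensor is about 244 MiB, and all of it stays alive until the output itself is deleted. A copy made with `.clone()` owns only its own small storage. The small CPU sketch below checks this with `Tensor.untyped_storage()`, which is available in PyTorch 2.x.

```python
import torch

B, M, N, k = 2, 100, 100, 2
dist = torch.rand(B, M, N)
_, idx = torch.sort(dist, dim=-1)           # full [B, M, N] int64 index tensor

knn_view = idx[:, :, 1:k + 1]               # slice: a view sharing idx's storage
knn_copy = idx[:, :, 1:k + 1].clone()       # independent copy with its own storage

del idx  # the name is gone, but knn_view still references the whole buffer
print(knn_view.untyped_storage().nbytes())  # 160000 bytes = B * M * N * 8
print(knn_copy.untyped_storage().nbytes())  # 3200 bytes   = B * M * k * 8
```

If this is the cause, returning `idx[:, :, 1:k+1].clone()` from Knn would let the full index tensor be freed when the function returns.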

This k-NN operation is done only once in my network, yet it consumes most of my GPU memory.
please give me some advice on this problem
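If peak memory is still a problem, one option is to process the queries in chunks. That way only a [B, chunk, N] block of distances exists at any time. The sketch below assumes, as the original `[1:k+1]` slice does, that query points are also reference points, so the zero-distance self match is dropped. `knn_chunked` and `chunk_size` are hypothetical names. It uses `torch.no_grad()` because the returned indices are not differentiable anyway.

```python
import torch

@torch.no_grad()  # no autograd graph is built, so no distances are saved for backward
def knn_chunked(ref_point, query_point, k, chunk_size=256, exclude_self=True):
    """Indices [B, M, k] of the k nearest reference points for each query point.

    Queries are processed chunk_size at a time, so the distance tensor held at
    any moment is [B, chunk_size, N] instead of [B, M, N].
    """
    extra = 1 if exclude_self else 0  # drop the self match, like [1:k+1]
    chunks = []
    for start in range(0, query_point.shape[1], chunk_size):
        q = query_point[:, start:start + chunk_size]       # [B, c, d]
        dist = torch.cdist(q, ref_point)                    # [B, c, N]
        _, idx = torch.topk(dist, k + extra, dim=-1, largest=False)
        chunks.append(idx[..., extra:])                     # [B, c, k]
    # torch.cat allocates a fresh tensor, so the per-chunk idx buffers can be freed
    return torch.cat(chunks, dim=1)
```

Smaller `chunk_size` values lower the peak memory at the cost of more kernel launches.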