# Understanding what consumes the memory in my for loop

Hello,

I am comparing two implementations of the same loss: the former (lossFunction_sinkhornlog) is supposed to be numerically more stable than the latter (lossFunction_sinkhorn), but both should theoretically return similar results.

```python
def lossFunction_sinkhornlog(samples, labels, dist_mat, eps):
    '''
    samples is what is predicted by the network
    labels is the target
    '''
    sampleSum = torch.sum(samples)
    labelSum = torch.sum(labels)

    a = samples.view(-1) / sampleSum
    b = labels.view(-1) / labelSum

    max_iter = 40

    u = torch.zeros_like(a)
    v = torch.zeros_like(b)
    dist_mat_norm = dist_mat / dist_mat.max()

    for i in range(max_iter):
        u = eps * (torch.log(a) - torch.logsumexp(Mf(dist_mat_norm, u, v, eps), dim=-1)) + u
        v = eps * (torch.log(b) - torch.logsumexp(torch.transpose(Mf(dist_mat_norm, u, v, eps), 1, 0), dim=-1)) + v

    pi = torch.exp((-dist_mat_norm + u.unsqueeze(-1) + v.unsqueeze(-2)) / eps)
    loss = torch.sum(pi * dist_mat)
    return loss
```
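A quick way to see that the whole loop ends up in the autograd graph (a minimal toy sketch with assumed small sizes, not my real data, where `samples` is the tensor that requires gradients):

```python
import torch

# toy check: after the update loop, u carries grad history, i.e. every
# iteration's intermediates are recorded in the autograd graph and kept
# alive until backward() is called
N, eps = 4, 0.1
a = torch.full((N,), 1.0 / N, requires_grad=True)
C = torch.rand(N, N)
u = torch.zeros(N)
v = torch.zeros(N)
for _ in range(3):
    M = (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / eps  # N x N intermediate
    u = eps * (torch.log(a) - torch.logsumexp(M, dim=-1)) + u
print(u.requires_grad)  # True
```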


The second method implementation (lossFunction_sinkhorn) is provided below:

```python
def lossFunction_sinkhorn(samples, labels, dist_mat, eps):
    '''
    samples is what is predicted by the network
    labels is the target
    '''
    sampleSum = torch.sum(samples)
    labelSum = torch.sum(labels)

    a = samples.view(-1) / sampleSum
    b = labels.view(-1) / labelSum

    max_iter = 1000

    u = torch.ones_like(a)
    v = torch.ones_like(b)
    dist_mat_norm = dist_mat / dist_mat.max()
    M = torch.exp(-dist_mat_norm / eps)

    for i in range(max_iter):
        u = torch.div(a, torch.mv(M, v))
        v = torch.div(b, torch.mv(torch.transpose(M, 0, 1), u))

    pi = torch.mm(torch.mm(torch.diag_embed(v), M), torch.diag_embed(u))
    loss = torch.sum(torch.transpose(pi, 1, 0) * dist_mat)
    return loss
```


I notice, though, that while the former consumes ~8 GB of GPU memory, the latter requires less than 1 GB. Most of the code is similar between the two methods, except for the for loop, where the iterations seem to become the memory bottleneck for the former. I am guessing that a copy of the computation graph is produced and retained for each loop iteration: I would be grateful for any intuition on why the former implementation consumes so much GPU memory! Thank you in advance.
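As a very rough back-of-envelope check (assuming a square N x N cost matrix with N around 1000, and counting only the obvious per-iteration intermediates while ignoring logsumexp temporaries), the number of floats autograd has to keep alive differs by more than an order of magnitude between the two loops:

```python
N = 1000                         # assumed problem size, for illustration only
iters_log, iters_plain = 40, 1000

# log-domain loop: each iteration materialises (at least) two N x N
# tensors (the Mf result for the u-update and again for the v-update),
# all retained for the backward pass
log_floats = iters_log * 2 * N * N

# plain loop: each iteration only materialises two N-element vectors;
# the N x N kernel M is built once, outside the loop
plain_floats = iters_plain * 2 * N + N * N

print(log_floats / plain_floats)
```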

The former method calls a helper function (Mf), which is provided below:

```python
def Mf(C, u, v, epsilon):
    r"$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"
    return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / epsilon
```


I don’t know which tensors require gradients and would thus be part of the computation graph.
I would guess that one of u, v, or dist_mat_norm requires gradients, and due to the broadcasting in Mf the memory usage might blow up:

```python
u = torch.randn(10, 10)
v = torch.randn(10, 10)
res = u.unsqueeze(-1) + v.unsqueeze(-2)
print(res.shape)
# torch.Size([10, 10, 10])
```


Could this be the reason?
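For the shapes actually used in the two loss functions above (if u and v are 1-D of length N, as in the posted code), the broadcast in Mf stays 2-D, but it still builds a full N x N matrix on every iteration, whereas the torch.mv updates in the second version only create length-N vectors inside the loop. A small sketch under that assumption:

```python
import torch

N = 1000
C = torch.rand(N, N)
u = torch.zeros(N)
v = torch.zeros(N)

# log-domain update: a fresh N x N intermediate per Mf call
mf_out = (-C + u.unsqueeze(-1) + v.unsqueeze(-2))
print(mf_out.shape)   # torch.Size([1000, 1000])

# plain Sinkhorn update: only N-element intermediates inside the loop;
# the N x N kernel is built once before the loop
M = torch.exp(-C)
mv_out = torch.mv(M, torch.ones(N))
print(mv_out.shape)   # torch.Size([1000])
```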