# Understanding what consumes the memory in my for loop

Hello,

I am comparing two implementations of the same loss: the former (lossFunction_sinkhornlog) is supposed to be numerically more stable than the latter (lossFunction_sinkhorn), but both should theoretically return similar results.

```python
def lossFunction_sinkhornlog(samples, labels, dist_mat, eps):
    '''
    samples is what is predicted by the network
    labels is the target
    '''
    sampleSum = torch.sum(samples)
    labelSum = torch.sum(labels)

    a = samples.view(-1) / sampleSum
    b = labels.view(-1) / labelSum

    max_iter = 40

    u = torch.zeros_like(a)
    v = torch.zeros_like(b)
    dist_mat_norm = dist_mat / dist_mat.max()

    for i in range(max_iter):
        u = eps * (torch.log(a) - torch.logsumexp(Mf(dist_mat_norm, u, v, eps), dim=-1)) + u
        v = eps * (torch.log(b) - torch.logsumexp(torch.transpose(Mf(dist_mat_norm, u, v, eps), 1, 0), dim=-1)) + v

    pi = torch.exp((-dist_mat_norm + u.unsqueeze(-1) + v.unsqueeze(-2)) / eps)
    loss = torch.sum(pi * dist_mat)
    return loss
```
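A quick way to see that the whole loop ends up in the autograd graph (a minimal toy sketch with assumed small sizes, not my real data, where `samples` is the tensor that requires gradients):

```python
import torch

# toy check: after the update loop, u carries grad history, i.e. every
# iteration's intermediates are recorded in the autograd graph and kept
# alive until backward() is called
N, eps = 4, 0.1
a = torch.full((N,), 1.0 / N, requires_grad=True)
C = torch.rand(N, N)
u = torch.zeros(N)
v = torch.zeros(N)
for _ in range(3):
    M = (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / eps  # N x N intermediate
    u = eps * (torch.log(a) - torch.logsumexp(M, dim=-1)) + u
print(u.requires_grad)  # True
```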


The second method implementation (lossFunction_sinkhorn) is provided below:

```python
def lossFunction_sinkhorn(samples, labels, dist_mat, eps):
    '''
    samples is what is predicted by the network
    labels is the target
    '''
    sampleSum = torch.sum(samples)
    labelSum = torch.sum(labels)

    a = samples.view(-1) / sampleSum
    b = labels.view(-1) / labelSum

    max_iter = 1000

    u = torch.ones_like(a)
    v = torch.ones_like(b)
    dist_mat_norm = dist_mat / dist_mat.max()
    M = torch.exp(-dist_mat_norm / eps)

    for i in range(max_iter):
        u = torch.div(a, torch.mv(M, v))
        v = torch.div(b, torch.mv(torch.transpose(M, 0, 1), u))

    pi = torch.mm(torch.mm(torch.diag_embed(v), M), torch.diag_embed(u))
    loss = torch.sum(torch.transpose(pi, 1, 0) * dist_mat)
    return loss
```


I notice, though, that while the former consumes ~8 GB of GPU memory, the latter requires less than 1 GB. Most of the code is similar between the two methods, except for the for loop, where the iterations seem to become the memory bottleneck for the former. I am guessing that a copy of the computation graph is produced and retained for each loop iteration: I would be grateful for any intuition on why the former implementation consumes so much GPU memory! Thank you in advance.
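As a very rough back-of-envelope check (assuming a square N x N cost matrix with N around 1000, and counting only the obvious per-iteration intermediates while ignoring logsumexp temporaries), the number of floats autograd has to keep alive differs by more than an order of magnitude between the two loops:

```python
N = 1000                         # assumed problem size, for illustration only
iters_log, iters_plain = 40, 1000

# log-domain loop: each iteration materialises (at least) two N x N
# tensors (the Mf result for the u-update and again for the v-update),
# all retained for the backward pass
log_floats = iters_log * 2 * N * N

# plain loop: each iteration only materialises two N-element vectors;
# the N x N kernel M is built once, outside the loop
plain_floats = iters_plain * 2 * N + N * N

print(log_floats / plain_floats)
```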

The former method calls a helper function (Mf), which is provided below:

```python
def Mf(C, u, v, epsilon):
    r"$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"
    return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / epsilon
```


I don’t know which tensors require gradients and would thus be part of the computation graph.
I would guess that one of u, v, or dist_mat_norm requires gradients, and due to the broadcasting in Mf the memory usage might blow up:

```python
u = torch.randn(10, 10)
v = torch.randn(10, 10)
res = u.unsqueeze(-1) + v.unsqueeze(-2)
print(res.shape)
# torch.Size([10, 10, 10])
```


Could this be the reason?
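For the shapes actually used in the two loss functions above (if u and v are 1-D of length N, as in the posted code), the broadcast in Mf stays 2-D, but it still builds a full N x N matrix on every iteration, whereas the torch.mv updates in the second version only create length-N vectors inside the loop. A small sketch under that assumption:

```python
import torch

N = 1000
C = torch.rand(N, N)
u = torch.zeros(N)
v = torch.zeros(N)

# log-domain update: a fresh N x N intermediate per Mf call
mf_out = (-C + u.unsqueeze(-1) + v.unsqueeze(-2))
print(mf_out.shape)   # torch.Size([1000, 1000])

# plain Sinkhorn update: only N-element intermediates inside the loop;
# the N x N kernel is built once before the loop
M = torch.exp(-C)
mv_out = torch.mv(M, torch.ones(N))
print(mv_out.shape)   # torch.Size([1000])
```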