Hello,
I am comparing two implementations of a Sinkhorn loss. The first (`lossFunction_sinkhornlog`) works in the log domain and is supposed to be numerically more stable than the second (`lossFunction_sinkhorn`), but in theory both should return similar results.
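Here is a minimal sketch of the stability difference. It uses a hypothetical 2×2 normalized cost matrix and a deliberately tiny `eps`, neither of which comes from my actual data. In float32 the kernel `exp(-C / eps)` used by the scaling update `u = a / (M v)` underflows to exactly zero, so the first update divides by zero. The log-domain update, written the same way as in `lossFunction_sinkhornlog` with `u = v = 0`, computes `eps * log(a / (M v))` through `torch.logsumexp` and stays finite.

```python
import torch

# Hypothetical normalized cost matrix (all entries > 0) and marginal
C = torch.tensor([[0.5, 1.0], [1.0, 0.5]])
a = torch.tensor([0.5, 0.5])
eps = 1e-3

# Scaling form: exp(-500) and exp(-1000) underflow to 0.0 in float32,
# so the kernel is all zeros and the first update divides by zero.
K = torch.exp(-C / eps)
u_scaling = a / torch.mv(K, torch.ones(2))

# Log form (first update with u = v = 0): the same quantity,
# eps * log(a / (K @ 1)), computed with logsumexp instead.
u_log = eps * (torch.log(a) - torch.logsumexp(-C / eps, dim=-1))

print(K)          # all zeros
print(u_scaling)  # inf
print(u_log)      # finite, about 0.4993
```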
```python
import torch

def lossFunction_sinkhornlog(samples, labels, dist_mat, eps):
    '''
    samples is what is predicted by the network
    labels is the target
    '''
    # Normalize both histograms to probability vectors
    sampleSum = torch.sum(samples)
    labelSum = torch.sum(labels)
    a = samples.view(-1) / sampleSum
    b = labels.view(-1) / labelSum
    max_iter = 40
    # Log-domain dual potentials
    u = torch.zeros_like(a)
    v = torch.zeros_like(b)
    dist_mat_norm = dist_mat / dist_mat.max()
    for i in range(max_iter):
        u = eps * (torch.log(a) - torch.logsumexp(Mf(dist_mat_norm, u, v, eps), dim=-1)) + u
        v = eps * (torch.log(b) - torch.logsumexp(torch.transpose(Mf(dist_mat_norm, u, v, eps), 1, 0), dim=-1)) + v
    # Transport plan pi_ij = exp((-C_ij + u_i + v_j) / eps)
    pi = torch.exp((-dist_mat_norm + u.unsqueeze(-1) + v.unsqueeze(-2)) / eps)
    loss = torch.sum(pi * dist_mat)
    return loss
```
The second implementation (`lossFunction_sinkhorn`) is below:
```python
def lossFunction_sinkhorn(samples, labels, dist_mat, eps):
    '''
    samples is what is predicted by the network
    labels is the target
    '''
    sampleSum = torch.sum(samples)
    labelSum = torch.sum(labels)
    a = samples.view(-1) / sampleSum
    b = labels.view(-1) / labelSum
    max_iter = 1000
    # Multiplicative scaling vectors
    u = torch.ones_like(a)
    v = torch.ones_like(b)
    dist_mat_norm = dist_mat / dist_mat.max()
    # Gibbs kernel K = exp(-C / eps)
    M = torch.exp(-dist_mat_norm / eps)
    for i in range(max_iter):
        u = torch.div(a, torch.mv(M, v))
        v = torch.div(b, torch.mv(torch.transpose(M, 0, 1), u))
    # diag(v) @ M @ diag(u), transposed below; this equals diag(u) @ M @ diag(v)
    # only because dist_mat (and hence M) is square and symmetric
    pi = torch.mm(torch.mm(torch.diag_embed(v), M), torch.diag_embed(u))
    loss = torch.sum(torch.transpose(pi, 1, 0) * dist_mat)
    return loss
```
I notice, though, that the former consumes ~8 GB of GPU memory, while the latter requires less than 1 GB. Most of the code is the same in both methods except for the `for` loop, and for the former the number of iterations appears to be the bottleneck, even though it runs only 40 iterations versus 1000. My guess is that a copy of the computation graph is produced and retained on every loop iteration. I would be grateful for any intuition on why the former implementation consumes so much GPU memory! Thank you in advance.
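One way to test this guess is sketched below. It assumes PyTorch 2.x, because it uses `Tensor.untyped_storage()`. `torch.autograd.graph.saved_tensors_hooks` lets you observe every tensor that autograd saves for the backward pass. The helper `count_saved_matrices` and the functions `log_domain_loss` and `scaling_loss` are hypothetical, simplified re-implementations of the two loops above. In particular, the scaling version forms the plan by broadcasting instead of `diag_embed`. The helper counts distinct saved buffers that are at least as large as one n×n float32 cost matrix, using a small made-up cost matrix.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks


def count_saved_matrices(loss_fn, n):
    """Run loss_fn() while recording every tensor autograd saves for backward.

    Returns (number of distinct saved buffers of at least n*n float32 values,
    total bytes of distinct saved buffers, loss).
    """
    saved = {}  # storage address -> storage size in bytes

    def pack(t):
        storage = t.untyped_storage()  # views share their base's storage
        saved[storage.data_ptr()] = storage.nbytes()
        return t  # observe only; keep the tensor as autograd would

    def unpack(t):
        return t

    with saved_tensors_hooks(pack, unpack):
        loss = loss_fn()
    big = sum(1 for nbytes in saved.values() if nbytes >= n * n * 4)
    return big, sum(saved.values()), loss


def Mf(C, u, v, eps):
    return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / eps


def log_domain_loss(a, b, C, eps, max_iter=40):
    # Condensed version of the lossFunction_sinkhornlog loop (hypothetical)
    u, v = torch.zeros_like(a), torch.zeros_like(b)
    for _ in range(max_iter):
        u = eps * (torch.log(a) - torch.logsumexp(Mf(C, u, v, eps), dim=-1)) + u
        v = eps * (torch.log(b) - torch.logsumexp(Mf(C, u, v, eps).t(), dim=-1)) + v
    return torch.sum(torch.exp(Mf(C, u, v, eps)) * C)


def scaling_loss(a, b, C, eps, max_iter=1000):
    # Condensed version of the lossFunction_sinkhorn loop (hypothetical)
    K = torch.exp(-C / eps)
    u, v = torch.ones_like(a), torch.ones_like(b)
    for _ in range(max_iter):
        u = a / torch.mv(K, v)
        v = b / torch.mv(K.t(), u)
    return torch.sum(u.unsqueeze(-1) * K * v.unsqueeze(-2) * C)


torch.manual_seed(0)
n = 64
x = torch.rand(n, 1)
C = (x - x.t()).abs()  # hypothetical symmetric cost matrix
C = C / C.max()
a = torch.softmax(torch.randn(n, requires_grad=True), dim=0)  # "prediction"
b = torch.full((n,), 1.0 / n)  # "target"

log_big, log_bytes, _ = count_saved_matrices(lambda: log_domain_loss(a, b, C, 0.1), n)
sc_big, sc_bytes, _ = count_saved_matrices(lambda: scaling_loss(a, b, C, 0.1), n)
print("log-domain:", log_big, "n x n buffers,", log_bytes, "bytes saved")
print("scaling:   ", sc_big, "n x n buffers,", sc_bytes, "bytes saved")
```

If the guess is right, the log-domain count should grow with `max_iter`, because every `logsumexp` call keeps its full n×n input alive for backward. The scaling count should stay small, because `torch.mv` only saves another reference to the same kernel matrix, plus O(n) vectors per iteration. Exact byte totals depend on the PyTorch version, so treat them as indicative.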
The former method calls a helper, `Mf`, which is defined below:
```python
def Mf(C, u, v, epsilon):
    r"""Modified cost for logarithmic updates:
    $M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$
    """
    return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / epsilon
```
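The following short derivation shows why the two loops should agree. Write $C$ for `dist_mat_norm` and $K_{ij} = e^{-C_{ij}/\epsilon}$ for the matrix `M` in `lossFunction_sinkhorn`. Define the scalings $\tilde u_i = e^{u_i/\epsilon}$ and $\tilde v_j = e^{v_j/\epsilon}$. With these definitions, the `Mf`-based update of `lossFunction_sinkhornlog` reduces to the multiplicative update:

$$
\begin{aligned}
u_i &\leftarrow \epsilon\Big(\log a_i - \log\sum_j \exp\big((-C_{ij} + u_i + v_j)/\epsilon\big)\Big) + u_i \\
    &= \epsilon \log a_i - \epsilon \log \sum_j K_{ij}\,\tilde v_j - u_i + u_i \\
    &= \epsilon \log \frac{a_i}{(K \tilde v)_i}
\quad\Longrightarrow\quad \tilde u_i \leftarrow \frac{a_i}{(K \tilde v)_i}.
\end{aligned}
$$

The same steps give $\tilde v_j \leftarrow b_j / (K^\top \tilde u)_j$. The plan $\pi_{ij} = e^{M_{ij}} = \tilde u_i K_{ij} \tilde v_j$ is then identical in both forms. In exact arithmetic the loops differ only in their iteration counts (40 versus 1000).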