Hi, I am trying to implement the CDL loss as defined in this paper (i.e., equation (3)) and have two questions:
- Can someone check whether my implementation is correct (code attached below)?
- In my implementation, I use a for loop to select each row and then process it following equation (3). Is there a way to get rid of this for loop? (I sketched one idea after the code below, but I am not sure it is correct.)
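For reference, this is the per-sample term I am trying to compute (my reading of equation (3), where \hat{y}_i are the predictions for sample i, y_i^1 the indices of its positive labels, and y_i^0 the indices of its negative labels):

E_i = \frac{1}{|y_i^1|\,|y_i^0|} \sum_{p \in y_i^1} \sum_{q \in y_i^0} \exp\left(\hat{y}_{i,q} - \hat{y}_{i,p}\right)

and the final loss is \sum_i E_i over the batch.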
import torch


def cdl(y_hat, y_true):
    """
    Implementation of CDL as defined in https://arxiv.org/pdf/1707.00418.pdf (equation (3)).

    :param y_hat: model predictions, shape (batch, classes)
    :param y_true: binary labels, shape (batch, classes)
    :return: loss (scalar)
    """
    loss = torch.zeros(y_true.size(0), device=y_hat.device)
    for idx, (y, y_h) in enumerate(zip(y_true, y_hat)):
        # indices of the negative (0) and positive (non-zero) labels for this sample
        y_z, y_o = (y == 0).nonzero(), y.nonzero()
        # exp(score of a negative label - score of a positive label), summed over all pairs
        output = torch.exp(torch.sub(y_h[y_z], y_h[y_o][:, None])).sum()
        # normalise by the number of (positive, negative) pairs;
        # assumes each sample has at least one positive and one negative label
        num_comparisons = y_z.size(0) * y_o.size(0)
        loss[idx] = output.div(num_comparisons)
    return loss.sum()
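For the second question, one idea I had is to use the fact that the double sum over (positive, negative) pairs factorises, since exp(a - b) = exp(a) * exp(-b). Here is a sketch of a loop-free version built on that (cdl_vectorized is just my placeholder name; it assumes y_true is a 0/1 tensor and every sample has at least one positive and one negative label):

def cdl_vectorized(y_hat, y_true):
    """
    Loop-free sketch of cdl (assumes 0/1 labels, >= 1 positive and >= 1 negative per row).
    """
    pos_mask = (y_true != 0).float()  # 1 where the label is positive
    neg_mask = (y_true == 0).float()  # 1 where the label is negative
    # sum over all (positive p, negative q) pairs of exp(y_hat_q - y_hat_p)
    # factorises into (sum_q exp(y_hat_q)) * (sum_p exp(-y_hat_p)) per sample
    neg_sum = (torch.exp(y_hat) * neg_mask).sum(dim=1)
    pos_sum = (torch.exp(-y_hat) * pos_mask).sum(dim=1)
    num_comparisons = pos_mask.sum(dim=1) * neg_mask.sum(dim=1)
    return (neg_sum * pos_sum / num_comparisons).sum()

Since the factorisation is exact, this should give the same value as the loop version up to floating-point error, e.g.:

y_hat = torch.randn(4, 10)
y_true = torch.randint(0, 2, (4, 10)).float()
print(cdl(y_hat, y_true), cdl_vectorized(y_hat, y_true))

One thing I am unsure about: the factorised form exponentiates the raw predictions directly, so it may be less numerically stable than the paired differences if the logits are large. Does that look right, or is there a better way?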
Thank you.