Implementing cross label dependency loss

Hi, I am trying to implement CLD loss as defined in this paper (i.e., equation 3) and have two questions:

  1. Can someone check whether my implementation is correct (code attached below)?
  2. In my implementation, I use a for loop to select each row and then process it following equation (3). Is there a way to get rid of this loop?
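import torch

def cdl(y_hat, y_true):
    """
    implementation of CDL as defined in
    :param y_hat: model predictions, shape(batch, classes)
    :param y_true: labels (batch, classes)
    :return: loss
    """
    # one loss value per sample, kept on the same device/dtype as the predictions
    loss = torch.zeros(y_true.size(0), dtype=y_hat.dtype, device=y_hat.device)
    for idx, (y, y_h) in enumerate(zip(y_true, y_hat)):
        # indices of the zero-labelled and non-zero-labelled classes in this row
        y_z, y_o = (y == 0).nonzero(), y.nonzero()
        # exp(score of zero-labelled class - score of non-zero-labelled class), summed over all pairs
        output = torch.exp(torch.sub(y_h[y_z], y_h[y_o][:, None])).sum()
        num_comparisons = y_z.size(0) * y_o.size(0)
        # average over the pairs (NaN if the row has no zeros or no non-zeros)
        loss[idx] = output.div(num_comparisons)
    return loss.sum()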

Thank you.
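For reference, here is the quantity the loop above computes, written out. It is read off the code rather than copied from the paper, so it is worth comparing against equation (3) term by term. The symbols B (batch size), Z_b and O_b (the zero-labelled and non-zero-labelled class indices of sample b) are notation introduced here.

```latex
\[
\mathcal{L}
  = \sum_{b=1}^{B} \frac{1}{|Z_b|\,|O_b|}
    \sum_{i \in Z_b} \sum_{j \in O_b}
    \exp\!\left(\hat{y}_{b,i} - \hat{y}_{b,j}\right),
\qquad
Z_b = \{\, i : y_{b,i} = 0 \,\},\quad
O_b = \{\, j : y_{b,j} \neq 0 \,\}
\]
```

Note that the per-sample terms are summed over the batch, not averaged, because the function returns loss.sum().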

If y_true looks like [[1, 1, 0], [0, 1, 1]] and y_hat looks like [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], try the following loop-free implementation. (It reproduces the computation in your loop, so it is only correct if your implementation is.)

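def cdl(y_hat, y_true):
    # pre_sub[b, i, j] = y_hat[b, i] - y_hat[b, j] for every pair of classes
    pre_sub = (y_hat[:, :, None] - y_hat[:, None, :])
    pre_exp = torch.exp(pre_sub)
    # boolean masks of zero-labelled and non-zero-labelled classes
    y_z, y_o = y_true == 0, y_true != 0
    # pairs per row; y_o.sum equals y_true.sum for 0/1 labels
    num_comparisons = y_z.sum(dim=1) * y_o.sum(dim=1)
    pre_exp = pre_exp / num_comparisons[:, None, None]
    # keep only pairs (i, j) where i is zero-labelled and j is non-zero-labelled
    mask = torch.logical_and(y_z[:, :, None], y_o[:, None, :])
    loss = pre_exp[mask].sum()
    return loss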

PS: Since pre_sub and pre_exp compute all classes × classes pairs for every row (a tensor of shape (batch, classes, classes)), the for-loop implementation may be faster when batch << classes (I am not sure).
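One observation that may sidestep the speed question: exp(a - b) = exp(a) * exp(-b), so the double sum over pairs factors into a product of two per-row sums, and the (batch, classes, classes) tensor never has to be built. Below is a minimal sketch of that idea, assuming labels mark the "negative" classes with 0 as in the example above. The name cdl_factored is made up for illustration, and exponentiating the raw scores separately can overflow sooner than exponentiating their differences, so treat it as a starting point rather than a drop-in replacement.

```python
import torch


def cdl_factored(y_hat, y_true):
    """Hypothetical O(batch * classes) variant of the pairwise loss above."""
    y_z = (y_true == 0).to(y_hat.dtype)  # 1.0 where the label is 0
    y_o = (y_true != 0).to(y_hat.dtype)  # 1.0 where the label is non-zero
    # Per-row sum of exp(score) over the zero-labelled classes.
    sum_z = (torch.exp(y_hat) * y_z).sum(dim=1)
    # Per-row sum of exp(-score) over the non-zero-labelled classes.
    sum_o = (torch.exp(-y_hat) * y_o).sum(dim=1)
    # Number of (zero, non-zero) pairs in each row.
    num_comparisons = y_z.sum(dim=1) * y_o.sum(dim=1)
    # Rows without any pair are skipped, i.e. they contribute 0 to the loss.
    valid = num_comparisons > 0
    per_row = (sum_z * sum_o)[valid] / num_comparisons[valid]
    return per_row.sum()


y_true = torch.tensor([[1, 1, 0], [0, 1, 1]])
y_hat = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
print(cdl_factored(y_hat, y_true))  # approximately 2.0251
```

For the first row the only zero-labelled score is 0.3, so the row contributes (exp(0.2) + exp(0.1)) / 2; the second row contributes (exp(-0.1) + exp(-0.2)) / 2.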

Thanks for the help, I checked your code and it produced the same results as mine. In my case, batch >> classes for sure.
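For anyone who wants to repeat that comparison, here is a small self-contained sketch that runs both versions on random data and checks that they agree. The names cdl_loop and cdl_vectorized are introduced here only to keep the two functions apart; the bodies follow the code posted above, with the loop version's indexing slightly simplified. The random shapes and the seed are arbitrary choices.

```python
import torch


def cdl_loop(y_hat, y_true):
    """Row-by-row version, following the loop posted in the question."""
    loss = torch.zeros(y_true.size(0), dtype=y_hat.dtype)
    for idx, (y, y_h) in enumerate(zip(y_true, y_hat)):
        y_z = (y == 0).nonzero().squeeze(1)  # indices of zero labels
        y_o = y.nonzero().squeeze(1)         # indices of non-zero labels
        # exp(zero-labelled score - non-zero-labelled score) for all pairs
        pairs = torch.exp(y_h[y_z][None, :] - y_h[y_o][:, None])
        loss[idx] = pairs.sum() / (y_z.numel() * y_o.numel())
    return loss.sum()


def cdl_vectorized(y_hat, y_true):
    """Loop-free version, following the reply."""
    y_z, y_o = y_true == 0, y_true != 0
    pre_exp = torch.exp(y_hat[:, :, None] - y_hat[:, None, :])
    num_comparisons = y_z.sum(dim=1) * y_o.sum(dim=1)
    pre_exp = pre_exp / num_comparisons[:, None, None]
    mask = torch.logical_and(y_z[:, :, None], y_o[:, None, :])
    return pre_exp[mask].sum()


torch.manual_seed(0)
y_hat = torch.randn(8, 5)
y_true = torch.randint(0, 2, (8, 5))
# Make sure every row has at least one 0 and one 1, so that no row
# ends up with zero comparisons (the loop version would return NaN there).
y_true[:, 0] = 0
y_true[:, 1] = 1

loss_loop = cdl_loop(y_hat, y_true)
loss_vec = cdl_vectorized(y_hat, y_true)
print(loss_loop.item(), loss_vec.item())
print(torch.allclose(loss_loop, loss_vec, rtol=1e-4))
```

The two versions differ in one edge case: for a row whose labels are all zero or all non-zero, the loop divides 0 by 0 and yields NaN, while the masked version selects nothing for that row and adds 0.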