The magnitude of KL divergence loss is too low compared to Cross-entropy loss

This is my train method for cross_entropy:

def train_crossentropy(train_iter, dev_iter, test_iter, model, args):
    print('training...')
    if args.cuda:
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    steps = 0
    best_accuracy = 0
    last_step = 0
    model.train()
    for epoch in range(1, args.epochs + 1):
        for batch in train_iter:
            feature, target = batch.text, batch.label

            feature = feature.data.t()
            target = target.data.sub(1)

            if args.cuda:
                feature, target = feature.cuda(), target.cuda()

            optimizer.zero_grad()
            logit = model(feature)

            loss = F.cross_entropy(logit, target)
            loss.backward()
            optimizer.step()

This is my train method for kl-loss:

def train_soft(train_iter, dev_iter, test_iter, model, args, temperature=1):
    print('training with train_soft...')
    if args.cuda:
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    steps = 0
    best_accuracy = 0
    last_step = 0
    model.train()
    for epoch in range(1, args.epochs + 1):
        for batch in train_iter:
            feature, target = batch.text, batch.label

            feature = feature.data.t()
            target = target.data.sub(1)
            
            ############# Soft-label changes start ######################
            target_un = target.unsqueeze(0)
            target_t = target_un.permute(1, 0)
            
            y_onehot = torch.FloatTensor(batch.batch_size, args.class_num)
            y_onehot.zero_()
            y_onehot.scatter_(1, target_t, 1)
            
            
            if args.cuda:
                feature, target, y_onehot = feature.cuda(), target.cuda(), y_onehot.cuda()

            optimizer.zero_grad()
            logit = model(feature)

            logits_flat = logit.view(-1, logit.size(-1))
            log_probs_flat = F.log_softmax(logits_flat / temperature, dim=1)
            target_flat = F.softmax(y_onehot / temperature, dim=1)
            
            loss = F.kl_div(log_probs_flat, target_flat)
            loss.backward()
            
            optimizer.step()

I’m using the same target values in both cases; the only difference is that in the KL version the targets are converted to one-hot format.
The cross-entropy loss values (for a small batch) are of order 10, while the KL loss values (for the same batch) are of order 0.001.
This leads to very slow gradient descent, and hence worse performance, when training with the KL loss.

Is this expected? I had an intuition that the loss values should be similar under this setup.

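There are a couple of caveats in the notes that come with the KLDivLoss documentation: the input is expected to contain log-probabilities, the target is expected to contain probabilities (by default), and reduction='mean' averages over all elements rather than over the batch, so reduction='batchmean' is needed to get the actual KL divergence.
Once you have taken care of these and use the correct input conventions, you should get the same loss and gradients for the things that ought to be mathematically equivalent (up to numerical error, of course).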
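
To make these caveats concrete, here is a minimal sketch on random data (the tensor names and sizes are made up for illustration, not taken from the question). It suggests where the small values come from: the default-style 'mean' reduction divides by batch size times number of classes, and a softmax over one-hot targets flattens them considerably.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_classes = 4, 5  # illustrative sizes, not from the question

logits = torch.randn(batch_size, num_classes)
target = torch.randint(num_classes, (batch_size,))
one_hot = F.one_hot(target, num_classes).float()
log_probs = F.log_softmax(logits, dim=1)  # kl_div expects log-probabilities as input

# 'mean' averages over all batch_size * num_classes elements
# (PyTorch emits a UserWarning recommending 'batchmean' here)
kl_mean = F.kl_div(log_probs, one_hot, reduction='mean')
# 'batchmean' sums over classes and divides by batch_size only
kl_batchmean = F.kl_div(log_probs, one_hot, reduction='batchmean')
kl_sum = F.kl_div(log_probs, one_hot, reduction='sum')

# softmax over a one-hot row no longer puts all the mass on the true class
softened = F.softmax(one_hot, dim=1)

print('kl mean     ', kl_mean.item())
print('kl batchmean', kl_batchmean.item())
print('largest softened target prob', softened.max().item())
```

With these conventions, the 'mean' value is the 'batchmean' value divided by the number of classes, which is one reason the loss looks too small.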

I was able to exactly align the loss values for KL-divergence and cross-entropy. See the code below:

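import torch
import torch.nn.functional as F

a = torch.randint(10, (2, 5))  # random integer "logits": batch of 2, 5 classes
b = torch.Tensor([[0., 0., 0., 1., 0.], [0., 0., 1., 0., 0.]])  # one-hot targets
target = torch.tensor([3, 2])  # the same targets as class indices (int64)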

print('a', a)
print('b', b)

a_log_soft = F.log_softmax(a.float(), dim=1)
print('a_log_soft', a_log_soft)
b_soft = F.softmax(b, dim=1)
print('b_soft', b_soft)

print('kl_loss batchmean', F.kl_div(a_log_soft, b, reduction='batchmean'))

print('kl_loss_prob batchmean', F.kl_div(a_log_soft, b_soft, reduction='batchmean'))

print('cross_loss mean', F.cross_entropy(a.float(), target))

Here, kl_loss batchmean matches cross_loss mean exactly, whereas kl_loss_prob batchmean does not, because applying softmax to one-hot targets turns them into a much flatter distribution. Softmax on the targets is only appropriate when b holds raw logits rather than probabilities; in that case kl_loss_prob batchmean is the one to use.
Also, make sure to use reduction='batchmean'.
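
The match with cross-entropy follows from the definition of KL divergence. A short derivation, writing p for the target distribution, q for the softmax of the model output, and y for the true class (notation introduced here for illustration, using the convention 0 log 0 = 0):

```latex
\begin{aligned}
D_{\mathrm{KL}}(p \,\|\, q)
  &= \sum_{c} p_c \log p_c - \sum_{c} p_c \log q_c
   = H(p, q) - H(p), \\
p = \text{one-hot}(y)
  &\;\Rightarrow\; H(p) = 0
   \;\Rightarrow\; D_{\mathrm{KL}}(p \,\|\, q) = -\log q_y .
\end{aligned}
```

The last expression is the per-example cross-entropy, so averaging it over the batch (reduction='batchmean') gives the same number as F.cross_entropy with its default mean reduction.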

Hence, in my original question all I need to do is change:

loss = F.kl_div(log_probs_flat, target_flat)

to

loss = F.kl_div(log_probs_flat, y_onehot, reduction='batchmean')

and this loss will exactly match the cross-entropy loss (with the default temperature=1).
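
As a sanity check, the following sketch compares both the loss values and the gradients of the two formulations on random logits. The shapes and variable names are assumptions made for the example; it stands in for the model output rather than reproducing the training loop above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_classes = 8, 3  # illustrative sizes

# two independent copies of the same logits so the gradients can be compared
logits_ce = torch.randn(batch_size, num_classes, requires_grad=True)
logits_kl = logits_ce.detach().clone().requires_grad_(True)

target = torch.randint(num_classes, (batch_size,))
y_onehot = F.one_hot(target, num_classes).float()

# cross-entropy on class indices
loss_ce = F.cross_entropy(logits_ce, target)

# KL divergence on one-hot targets: log-probabilities in, probabilities as target
log_probs = F.log_softmax(logits_kl, dim=1)
loss_kl = F.kl_div(log_probs, y_onehot, reduction='batchmean')

loss_ce.backward()
loss_kl.backward()

print('losses match:   ', torch.allclose(loss_ce, loss_kl))
print('gradients match:', torch.allclose(logits_ce.grad, logits_kl.grad, atol=1e-6))
```

Note that this equality assumes temperature=1; with another temperature the KL loss above corresponds to the cross-entropy of logit / temperature, not of the raw logits.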
