This is my train method for cross_entropy:
def train_crossentropy(train_iter, dev_iter, test_iter, model, args):
print('training...')
if args.cuda:
model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
steps = 0
best_accuracy = 0
last_step = 0
model.train()
for epoch in range(1, args.epochs + 1):
for batch in train_iter:
feature, target = batch.text, batch.label
feature = feature.data.t()
target = target.data.sub(1)
if args.cuda:
feature, target = feature.cuda(), target.cuda()
optimizer.zero_grad()
logit = model(feature)
loss = F.cross_entropy(logit, target)
loss.backward()
optimizer.step()
This is my train method for kl-loss:
def train_soft(train_iter, dev_iter, test_iter, model, args, temperature=1):
print('training with train_soft...')
if args.cuda:
model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
steps = 0
best_accuracy = 0
last_step = 0
model.train()
for epoch in range(1, args.epochs + 1):
for batch in train_iter:
feature, target = batch.text, batch.label
feature = feature.data.t()
target = target.data.sub(1)
############# Soft-label changes start ######################
target_un = target.unsqueeze(0)
target_t = target_un.permute(1, 0)
y_onehot = torch.FloatTensor(batch.batch_size, args.class_num)
y_onehot.zero_()
y_onehot.scatter_(1, target_t, 1)
if args.cuda:
feature, target, y_onehot = feature.cuda(), target.cuda(), y_onehot.cuda()
optimizer.zero_grad()
logit = model(feature)
logits_flat = logit.view(-1, logit.size(-1))
log_probs_flat = F.log_softmax(logits_flat / temperature, dim=1)
target_flat = F.softmax(y_onehot / temperature, dim=1)
loss = F.kl_div(log_probs_flat, target_flat)
loss.backward()
optimizer.step()
I’m using the same target values in both cases. Just that in the KL version, the target values are converted into a one-hot format.
I cross-entropy loss values (for a small batch) is of order 10, while kl-loss values (for the same batch) is of order 0.001.
This is leading to very slow gradient descent and hence worse performance (using kl-loss).
Is this expected? I had an intuition that the loss values should be similar under this setup.