I have two models: an SI (speaker independent) model, which is already trained, and an SD (speaker dependent) model, which is to be learned. At the beginning they are identical. I want to adapt the SI model to a new speaker by minimizing CTC loss on the SD data, starting from the SI model. Since I do not want to overfit, I add a weighted KLD loss to the CTC loss, to keep the adapted model from drifting too far from the SI model. The weighting factor is self.mu in the loss. The loss code looks like this:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
from warpctc_pytorch import CTCLoss
from torch.nn import KLDivLoss as KLD

class CTC_KLD(nn.Module):
    def __init__(self, mu):
        super(CTC_KLD, self).__init__()
        self.mu = mu
        self.ctc_loss = CTCLoss(length_average=True)
        self.KLD = KLD(size_average=False)

    def forward(self, SI_logits, SD_logits, SD_targets, SD_target_sizes, input_sizes, input_sizes_list):
        # warp-ctc expects logits of size (T, N, D)
        SD_logits_ctc = torch.transpose(SD_logits, 0, 1).contiguous()
        CTC_loss = self.ctc_loss(SD_logits_ctc, SD_targets, input_sizes, SD_target_sizes).type(torch.cuda.FloatTensor)
        # Drop the padded frames before computing the KLD term
        SI_logits = rnn_utils.pack_padded_sequence(SI_logits, input_sizes_list, batch_first=True).data
        SD_logits_KL = rnn_utils.pack_padded_sequence(SD_logits, input_sizes_list, batch_first=True).data
        batch_size = SI_logits.size(0)
        log_probs_SD = F.log_softmax(SD_logits_KL, dim=1)   # KLDivLoss expects log-probabilities as input
        probs_SI = F.softmax(SI_logits, dim=1)              # and probabilities as target
        KLD_loss = self.KLD(log_probs_SD, probs_SI) / batch_size
        loss = (1.0 - self.mu) * CTC_loss + self.mu * KLD_loss
        return loss
In the main script, an input variable x of size (N, T, D), containing N sequences from the new speaker, first goes through both the SI and SD models to obtain SI_logits and SD_logits. I then detach SI_logits from the graph with SI_logits = SI_logits.detach(), since the SI model should not be updated; it only provides the targets for the KLD loss. Finally, I pass SI_logits and SD_logits to the loss function.
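For reference, the adaptation step looks roughly like this (a minimal sketch on CPU with tiny stand-in models; SI_model, SD_model, and the sizes are placeholders, not my actual acoustic models):

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the SI/SD acoustic models (same architecture).
SI_model = nn.Linear(5, 4)
SD_model = nn.Linear(5, 4)
SD_model.load_state_dict(SI_model.state_dict())  # SD starts from the SI weights

x = torch.randn(2, 3, 5)           # (N, T, D): N sequences from the new speaker

SI_logits = SI_model(x).detach()   # SI is frozen: it only provides KLD targets
SD_logits = SD_model(x)            # SD stays in the graph and receives gradients

# loss = criterion(SI_logits, SD_logits, ...)  # CTC_KLD as defined above
```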
In the above loss code, if I write loss = CTC_loss, training works fine. But when I write loss = 1.0 * CTC_loss + 0.0 * KLD_loss (i.e. self.mu = 0), the result (measured by word error rate in speech recognition) becomes very different from simply writing loss = CTC_loss, even though the two should be the same loss function (CTC only). Does anyone have an idea why they differ so much?
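One thing I did consider (a toy sketch, not my actual model): a zero-weighted term only leaves the gradients untouched when that term is finite. If the KLD term ever becomes inf or NaN, then 0.0 * KLD_loss is NaN in the forward pass, and the chain rule turns every gradient NaN as well:

```python
import torch

w = torch.ones(3, requires_grad=True)

# Finite extra term: 0.0 * extra contributes nothing to the gradient.
base = (w * 2.0).sum()             # stand-in for CTC_loss
extra = (w ** 2).sum()             # stand-in for a well-behaved KLD_loss
(1.0 * base + 0.0 * extra).backward()
print(w.grad)                      # tensor([2., 2., 2.]), same as base alone

# Non-finite extra term: 0.0 * inf = nan, and it poisons the whole backward.
w.grad = None
base = (w * 2.0).sum()             # rebuilt; the first graph was freed
bad = (w * float('inf')).sum()     # stand-in for a KLD_loss that overflowed
(1.0 * base + 0.0 * bad).backward()
print(w.grad)                      # tensor([nan, nan, nan])
```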