Hi,
I am trying to implement a masked KL-divergence loss with label smoothing. The code is below — can anyone confirm whether this is the right way to go about it? Thanks!
Context: Calculating loss in Sequence2Sequence architecture:
Decoder output format: Batch x Sequence Length x Vocab size
Input format: Batch x Sequence Length (Note: this is not one hot encoded)
import torch
import torch.nn as nn
import torch.nn.functional as F


class KLDivergenceLossWithMask(nn.Module):
    """KL-divergence loss with label smoothing that ignores padded positions.

    Intended for seq2seq training:
      * decoder output (``inp``): raw logits, shape (batch, seq_len, vocab)
      * target (``tgt``): class indices, shape (batch, seq_len), NOT one-hot

    The smoothed target distribution puts ``1 - label_smoothing`` on the true
    class and spreads ``label_smoothing`` uniformly over the remaining classes,
    excluding the padding index (hence the ``vocab - 2`` denominator), so each
    row sums to 1. Positions whose target equals ``ignore_value`` are masked
    out and do not contribute to the loss.
    """

    def __init__(self, label_smoothing, target_vocab, ignore_value=102):
        """
        label_smoothing: total probability mass moved off the true label (0..1)
        target_vocab:    vocabulary size V
        ignore_value:    target index that marks padding (default 102)
        """
        super().__init__()
        self.ignore_value = ignore_value
        self.label_smoothing = label_smoothing
        self.confidence = 1.0 - label_smoothing
        self.target_vocab = target_vocab
        # Smoothing mass is shared by every class except the true label and
        # the padding index, so confidence + (V - 2) * smoothing_value == 1.
        smoothing_value = label_smoothing / (target_vocab - 2)
        one_hot = torch.full((target_vocab,), smoothing_value)
        one_hot[ignore_value] = 0.0  # no probability mass on the pad token
        # register_buffer so .to(device)/.cuda() moves it along with the module
        self.register_buffer("one_hot", one_hot.unsqueeze(0))

    def forward(self, inp, tgt):
        """Return the masked, label-smoothed KL loss averaged per real token.

        inp: logits, (batch, seq_len, vocab)
        tgt: indices, (batch, seq_len); entries == ignore_value are padding
        """
        # Expand the smoothed template and drop the confidence on the target.
        model_prob = self.one_hot.repeat(tgt.size(0), tgt.size(1), 1)
        model_prob.scatter_(2, tgt.unsqueeze(2), value=self.confidence)
        # Zero the whole distribution at padded steps. The logits themselves
        # need no masking: KLDivLoss treats 0 * log(0) as 0, so positions with
        # an all-zero target distribution contribute nothing to the sum.
        mask = (tgt != self.ignore_value).unsqueeze(2)
        model_prob = model_prob * mask
        loss = F.kl_div(F.log_softmax(inp, dim=2), model_prob, reduction='sum')
        # Average per non-pad token so the value is comparable across batches
        # with different amounts of padding ('batchmean' would divide by the
        # batch size only, making the loss scale with sequence length).
        n_tokens = mask.sum().clamp(min=1)
        return loss / n_tokens
# Smoke test: a batch of 4 sequences of length 7 over a 12000-word
# vocabulary, using 102 as the padding id.
logits = torch.randn(4, 7, 12000).float()  # batch, seq_len, vocab
targets = torch.LongTensor([
    [23, 34, 22, 67, 10122, 11343, 102],
    [9999, 999, 99, 9, 12, 344, 6678],
    [1133, 9434, 2234, 102, 102, 102, 102],
    [0, 1, 2, 3, 102, 102, 102],
])  # batch, seq_len
print(logits.shape, targets.shape)

criterion = KLDivergenceLossWithMask(
    label_smoothing=0.2,
    target_vocab=12000,
    ignore_value=102,
)
loss = criterion(logits, targets)
print(loss)