I customize a Loss module for seq2seq model. Since inputs are padded to max_seq_len for batch training, so I need a loss that can pass in
mask in order to ignore
import torch.nn as nn import torch from torch.autograd import Variable def _assert_no_grad(variable): assert not variable.requires_grad, \ "nn criterions don't compute the gradient w.r.t. targets - please " \ "mark these variables as volatile or not requiring gradients" class SequenceLoss(nn.Module): def __init__(self, average_cross_timesteps=True, average_cross_batch=True): super(SequenceLoss, self).__init__() self.average_cross_timesteps = average_cross_timesteps self.average_cross_batch = average_cross_batch self.loss_func = nn.CrossEntropyLoss(size_average=False) def forward(self, input, target, weight): batch_size = input.size() max_seq_len = input.size() _assert_no_grad(target) _assert_no_grad(weight) num_classes = input.size()[-1] logits_flat = input.view(-1, num_classes) targets_flat = target.view(-1) logits_flat_list = torch.split(logits_flat, 1, dim=0) targets_flat_list = torch.split(targets_flat, 1, dim=0) tmp = [self.loss_func(logits_flat_list[i], targets_flat_list[i]).data for i in range(batch_size*max_seq_len)] crossent = Variable(torch.cat(tmp, dim=0), requires_grad=True) #Variable(torch.Tensor(tmp)) crossent = crossent * weight.view(-1) if self.average_cross_timesteps and self.average_cross_batch: # char level ppl crossent = torch.sum(crossent) total_size = torch.sum(weight) total_size += 1e-12 # to avoid division by 0 for all-0 weights crossent = crossent / total_size else: sequence_length = input.size() crossent = crossent.view(batch_size, sequence_length) if self.average_cross_timesteps and not self.average_cross_batch: crossent = torch.sum(crossent, 1) total_size = torch.sum(weight, 1) total_size += 1e-12 crossent = crossent / total_size if not self.average_cross_timesteps and self.average_cross_batch: # return crossent = [time_steps] crossent = torch.sum(crossent, 0) total_size = torch.sum(weight, 0) total_size += 1e-12 crossent = crossent / total_size return crossent
And loss doesn’t go down in training process. I don’t know where the problem is. Could anybody give me some advice?