Hi all,
I customize a Loss module for seq2seq model. Since inputs are padded to max_seq_len for batch training, so I need a loss that can pass in mask
in order to ignore PAD
loss.
import torch.nn as nn
import torch
from torch.autograd import Variable
def _assert_no_grad(variable):
assert not variable.requires_grad, \
"nn criterions don't compute the gradient w.r.t. targets - please " \
"mark these variables as volatile or not requiring gradients"
class SequenceLoss(nn.Module):
def __init__(self, average_cross_timesteps=True, average_cross_batch=True):
super(SequenceLoss, self).__init__()
self.average_cross_timesteps = average_cross_timesteps
self.average_cross_batch = average_cross_batch
self.loss_func = nn.CrossEntropyLoss(size_average=False)
def forward(self, input, target, weight):
batch_size = input.size()[0]
max_seq_len = input.size()[1]
_assert_no_grad(target)
_assert_no_grad(weight)
num_classes = input.size()[-1]
logits_flat = input.view(-1, num_classes)
targets_flat = target.view(-1)
logits_flat_list = torch.split(logits_flat, 1, dim=0)
targets_flat_list = torch.split(targets_flat, 1, dim=0)
tmp = [self.loss_func(logits_flat_list[i], targets_flat_list[i]).data for i in range(batch_size*max_seq_len)]
crossent = Variable(torch.cat(tmp, dim=0), requires_grad=True) #Variable(torch.Tensor(tmp))
crossent = crossent * weight.view(-1)
if self.average_cross_timesteps and self.average_cross_batch:
# char level ppl
crossent = torch.sum(crossent)
total_size = torch.sum(weight)
total_size += 1e-12 # to avoid division by 0 for all-0 weights
crossent = crossent / total_size
else:
sequence_length = input.size()[1]
crossent = crossent.view(batch_size, sequence_length)
if self.average_cross_timesteps and not self.average_cross_batch:
crossent = torch.sum(crossent, 1)
total_size = torch.sum(weight, 1)
total_size += 1e-12
crossent = crossent / total_size
if not self.average_cross_timesteps and self.average_cross_batch:
# return crossent = [time_steps]
crossent = torch.sum(crossent, 0)
total_size = torch.sum(weight, 0)
total_size += 1e-12
crossent = crossent / total_size
return crossent
And loss doesn’t go down in training process. I don’t know where the problem is. Could anybody give me some advice?