How can I compute seq2seq loss using a mask?

I am working on an image captioning task with PyTorch.
In seq2seq models, padding is used to handle variable-length sequences.
In addition, the per-element loss (a vector, not a scalar) is multiplied by a mask so that the padding does not affect the loss.

In TensorFlow, I can do this as below.

# targets is an int64 tensor of shape (batch_size, padded_length) containing word indices.
# masks is a tensor of shape (batch_size, padded_length) containing 0 or 1 (0 for padding, 1 otherwise).

outputs = decoder(...)  # unnormalized scores of shape (batch_size, padded_length, vocab_size)
outputs = tf.reshape(outputs, (-1, vocab_size))
targets = tf.reshape(targets, [-1])
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=outputs)  # per-token loss of shape (batch_size*padded_length,)

masks = tf.reshape(masks, [-1])
loss = losses * masks

In PyTorch, nn.CrossEntropyLoss() returns a scalar rather than a tensor of per-element losses, so I cannot multiply the loss by the masks.

criterion = nn.CrossEntropyLoss()
outputs = decoder(features, inputs)   # (batch_size, padded_length, vocab_size)
loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))  # this gives a scalar not tensor

How can I solve this problem?


A non-averaged cross-entropy loss is coming soon. Until then you can write your own using log_softmax and advanced indexing.
In addition, though I don’t think it helps you here, nn.LSTM now has support for variable-length sequences without including padding, meaning that sequence model results will not be affected by the influence of padding tokens even with bidirectional RNNs. There are utility functions provided for creating the packed array data structure (~TF’s TensorArray) needed for this.


How can I use nn.LSTM with variable-length sequences and without padding?
In the code below, torch.utils.data.DataLoader stacks the individual tensors to build each mini-batch. This forces me to pad every sequence so that the tensors have a fixed size.

cap = CocoCaptions(root = './data/train2014resized',
                   annFile = './data/annotations/captions_train2014.json',
                   vocab = vocab,
                   transform=transform,
                   target_transform=transforms.ToTensor())

data_loader = torch.utils.data.DataLoader(
    cap, batch_size=16, shuffle=True, num_workers=2)

for i, (images, input_seqs, target_seqs, masks) in enumerate(data_loader):
    # images: a tensor of shape (batch_size, 3, 256, 256).
    # input_seqs, target_seqs, masks: tensors of shape (batch_size, padded_length).

If padded is a Variable with padding in it and lengths is a tensor containing the length of each sequence, then this is how to run a (potentially bidirectional) LSTM over the sequences in a way that doesn’t include padding, then pad the result in order to use it in further computations.

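import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
# num_layers and bidirectional are passed by keyword; positionally, the third argument of nn.LSTM is num_layers.
lstm = nn.LSTM(in_size, hidden_size, num_layers=num_layers, bidirectional=bidirectional)
# padded is time-major here: (max_length, batch_size, in_size), sorted by decreasing length.
packed = rnn_utils.pack_padded_sequence(padded, lengths)
packed_out, hidden = lstm(packed)  # hidden is the (h_n, c_n) tuple
unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed_out)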
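
For reference, here is a self-contained sketch of the same pack / run / unpack round trip. The sizes and random data are made up for illustration, and it assumes a recent PyTorch release where pack_padded_sequence accepts batch_first and enforce_sorted; it is not the code from the post above.

```python
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

torch.manual_seed(0)

# Illustrative sizes (assumptions, not from the thread).
batch_size, max_len, in_size, hidden_size = 3, 5, 4, 6
lengths = torch.tensor([5, 3, 2])              # true length of each sequence
padded = torch.randn(batch_size, max_len, in_size)
for i, n in enumerate(lengths.tolist()):
    padded[i, n:] = 0.0                        # zero out the padded time steps

lstm = nn.LSTM(in_size, hidden_size, num_layers=1,
               bidirectional=True, batch_first=True)

# enforce_sorted=False lets the batch arrive in any length order.
packed = rnn_utils.pack_padded_sequence(
    padded, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)

# Back to a padded tensor for use in later layers.
unpacked, unpacked_len = rnn_utils.pad_packed_sequence(
    packed_out, batch_first=True)

print(unpacked.shape)   # (batch_size, max_len, 2 * hidden_size)
print(unpacked_len)     # the lengths that were passed in
```

The positions past each sequence's length come back as zeros in unpacked, so a mask is still needed if a loss is computed on them later.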

https://github.com/Element-Research/rnn has a MaskZero module. Won’t PyTorch also need a wrapper to deal with padded inputs and with the gradients between the LSTM outputs and the final layer? Is there a plan for that, or is there another elegant way to deal with it?

@supakjk there’s no need for any wrapper. As long as the padded entries don’t contribute to your loss in any way, their gradient will always be 0.
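
A small check of this claim, as a sketch with made-up shapes and a hand-written mask (all names here are illustrative): once the per-token loss is multiplied by a 0/1 mask, the gradient with respect to the scores at padded positions is exactly zero.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size = 5
logits = torch.randn(2, 4, vocab_size, requires_grad=True)  # (batch, time, vocab)
targets = torch.randint(0, vocab_size, (2, 4))
mask = torch.tensor([[1., 1., 1., 0.],
                     [1., 1., 0., 0.]])                      # 0 marks padding

# Per-token loss, then mask out the padded positions before averaging.
token_loss = F.cross_entropy(
    logits.view(-1, vocab_size), targets.view(-1), reduction="none")
loss = (token_loss * mask.view(-1)).sum() / mask.sum()
loss.backward()

# Gradients at padded positions are all zero; the rest are not.
grad_is_zero_at_padding = bool((logits.grad[mask == 0] == 0).all())
grad_is_nonzero_elsewhere = bool((logits.grad[mask == 1] != 0).any())
print(grad_is_zero_at_padding, grad_is_nonzero_elsewhere)
```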


@jekbradbury Could you elaborate on how log_softmax and advanced indexing can be used to compute cross-entropy with masks?

The cross-entropy loss for a particular vector of output scores and a target index is the value at that index of the negative log softmax of the vector of scores. So you can run negative log softmax on the whole score tensor, pick out the values you want using gather (advanced indexing was briefly semi-supported for this, and will be fully supported eventually), then sum or average the results.
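
The recipe above could be sketched as follows. The function name masked_cross_entropy, the batch-first layout, and the example tensors are assumptions made for illustration; the calls themselves (F.log_softmax and Tensor.gather) are standard PyTorch.

```python
import torch
import torch.nn.functional as F

def masked_cross_entropy(logits, targets, mask):
    """Mean cross-entropy over positions where mask is 1.

    logits:  (batch_size, padded_length, vocab_size) unnormalized scores
    targets: (batch_size, padded_length) int64 word indices
    mask:    (batch_size, padded_length) 0 for padding, 1 otherwise
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # gather needs an index with the same number of dims as log_probs.
    nll = -log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
    mask = mask.to(nll.dtype)
    return (nll * mask).sum() / mask.sum()

torch.manual_seed(0)
logits = torch.randn(2, 4, 10)
targets = torch.randint(0, 10, (2, 4))
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
loss = masked_cross_entropy(logits, targets, mask)
print(loss)
```

Dividing by mask.sum() averages over real tokens only; dividing by the batch size instead is another common choice.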

Could you let me know how to deal with the following example case?

Assume that I have an input where the maximum sequence length is 5, the minibatch size is 4, and there are 3 possible labels. idxs holds the effective lengths of the sequences in the input.

input = Variable(torch.randn(5,4,3))
idxs = [5,3,3,2]
target = Variable(torch.LongTensor(5,4))
# assume the target labels are assigned corresponding to the input.

Then, in a sequence tagging task, I’d like to get the cross-entropy errors for all time steps of the first sequence, the first 3 time steps of the second sequence, and so on, according to the values of idxs.
How could I use advanced indexing to handle this sequence tagging task with variable-length input sequences?

(I thought masked_select might be used for my purpose but I wonder what would be the most elegant one at this moment before some other features are added.)

Thanks!
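
One possible way to handle this case, sketched with boolean-mask indexing (which plays the role masked_select would). It assumes a PyTorch version where plain tensors replace Variable, and it uses the name scores instead of input to avoid shadowing the Python builtin; the random targets are placeholders.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

max_len, batch_size, num_labels = 5, 4, 3
scores = torch.randn(max_len, batch_size, num_labels)         # time-major scores
target = torch.randint(0, num_labels, (max_len, batch_size))  # placeholder labels
idxs = torch.tensor([5, 3, 3, 2])                             # effective lengths

# valid[t, b] is True when time step t lies inside sequence b.
steps = torch.arange(max_len).unsqueeze(1)    # (max_len, 1)
valid = steps < idxs.unsqueeze(0)             # (max_len, batch_size)

# Boolean indexing keeps only the valid positions.
valid_scores = scores[valid]                  # (sum(idxs), num_labels)
valid_target = target[valid]                  # (sum(idxs),)

loss = F.cross_entropy(valid_scores, valid_target)
print(valid_scores.shape, loss)
```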

pack_padded_sequence can be placed at any point in the DAG, right? That is, given a sequence, I can apply some dense layer to it, then pack it and give it to the LSTM?

@pranav it’s packed in an additional structure to hold the sequence lengths, but you might take out its .data attribute, compute a function on that, and rewrap it in a new torch.nn.utils.rnn.PackedSequence object.
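
A sketch of that rewrapping, assuming a recent PyTorch release in which PackedSequence also carries sorted_indices and unsorted_indices; the layer sizes and data are made up for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import (
    PackedSequence, pack_padded_sequence, pad_packed_sequence)

torch.manual_seed(0)

padded = torch.randn(3, 4, 8)          # (batch, time, features)
lengths = torch.tensor([4, 2, 1])      # sorted by decreasing length
packed = pack_padded_sequence(padded, lengths, batch_first=True)

# Apply a dense layer to the flat .data tensor, which holds real tokens only.
dense = nn.Linear(8, 16)
projected = PackedSequence(
    dense(packed.data),
    packed.batch_sizes,
    packed.sorted_indices,
    packed.unsorted_indices,
)

lstm = nn.LSTM(16, 32, batch_first=True)
packed_out, _ = lstm(projected)
out, out_len = pad_packed_sequence(packed_out, batch_first=True)
print(projected.data.shape, out.shape)
```

Applying the dense layer to the padded tensor before packing gives the same result for the real tokens; working on .data just avoids computing it for padding.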


The non-averaged cross-entropy loss hasn’t arrived yet, but it can easily be implemented with torch.gather.
For example, suppose

  1. outputs holds the log-probabilities for each label, with shape [batch_size, time_step, num_token] (if you have raw scores instead, you can always convert them with F.log_softmax)
  2. targets contains the target label for each step, with shape [batch_size, time_step]

Then the cross-entropy loss could be calculated with

def nll_loss(outputs, targets):
    return torch.gather(outputs, 1, targets.view(-1,1))

Then you can apply the masks as in TensorFlow.

Can you help me understand this torch.gather call? I have used the same tensors with similar shapes, but the line torch.gather(outputs, 1, targets.view(-1,1)) gives me an error.

Thanks for sharing the code. I tested the nll_loss function and got:

RuntimeError: Input tensor must have same dimensions as output tensor at /b/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:424

It seems that the input and index arguments of gather function should have same dimensions.
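
One way to satisfy that requirement, as a sketch rather than the original poster's intended code: flatten the log-probabilities to two dimensions so that the (N, 1) index has the same number of dimensions, and negate the gathered values. The example shapes are made up.

```python
import torch
import torch.nn.functional as F

def nll_loss(log_probs, targets):
    """Per-token negative log-likelihood, shape (batch_size * time_step,).

    log_probs: (batch_size, time_step, num_token) log-probabilities
    targets:   (batch_size, time_step) int64 labels
    """
    num_token = log_probs.size(-1)
    flat = log_probs.reshape(-1, num_token)   # (N, num_token)
    index = targets.reshape(-1, 1)            # (N, 1): same number of dims as flat
    return -flat.gather(1, index).squeeze(1)

torch.manual_seed(0)
scores = torch.randn(2, 3, 7)                 # raw decoder scores
targets = torch.randint(0, 7, (2, 3))
losses = nll_loss(F.log_softmax(scores, dim=-1), targets)
print(losses.shape)                           # one loss per token
```

The result can then be multiplied by a flattened mask, as in the TensorFlow snippet at the top of the thread.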

Currently I compute the masked loss for a seq2seq model this way:

'''
target_sentences: [batch_size, max_length] int64
batch_target_len_array: [batch_size] int64, each element is the valid length of the corresponding target sentence. 
'''
for step in range(max_length):
    decoder_output, decoder_hidden = decoder(
        decoder_input, decoder_hidden, encoder_outputs)
    if np.sum(batch_target_len_array > step) == 0:
        break
    mask_vector = torch.from_numpy((batch_target_len_array > step).astype(np.int32)).byte()
    index_vector = Variable(torch.masked_select(torch.arange(0, batch_size), mask_vector).long()).cuda()
    valid_output = torch.index_select(decoder_output, 0, index_vector)
    valid_target = torch.index_select(target_sentences[:, step], 0, index_vector)

    step_loss = criterion(valid_output, valid_target)
    loss += step_loss
    decoder_input = target_sentences[:, step]  # Teacher forcing

In version 0.2.0 they added ignore_index in CrossEntropyLoss. I think this addition solves the problem. I use it like this:

loss_function = nn.CrossEntropyLoss(ignore_index=0)

since I have zero-padded targets.

I used to use this masked_cross_entropy workaround, but I tested them both and observed the same behaviour (compared to CrossEntropyLoss without ignore_index=0).
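
A quick way to see why the two approaches can agree, sketched with made-up data. It assumes index 0 is reserved for padding and never appears as a real target, and that the masked loss is averaged over non-padding tokens.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

PAD = 0                                     # assumed padding index
vocab_size = 6
logits = torch.randn(2, 4, vocab_size)      # (batch, time, vocab)
targets = torch.tensor([[3, 1, 5, PAD],
                        [2, 4, PAD, PAD]])
flat_logits = logits.view(-1, vocab_size)
flat_targets = targets.view(-1)

# Built-in: padded targets are skipped and excluded from the mean.
loss_ignore = nn.CrossEntropyLoss(ignore_index=PAD)(flat_logits, flat_targets)

# Manual: per-token loss times a 0/1 mask, averaged over real tokens.
mask = (flat_targets != PAD).float()
per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none")
loss_masked = (per_token * mask).sum() / mask.sum()

print(loss_ignore, loss_masked)
```

If the masked version is normalized differently (for example by batch size), the two values will differ by a constant factor.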


I have tried it, but it doesn’t work in my case. Could you post your code in a bit more detail?

It is clear. I’ll try! The newest version, 0.3, added a new parameter “reduce” to the loss function, which makes it easy to mask the loss.

Can you explain how to mask the loss now?

With the parameter “reduce”, we can get the loss per batch element, but how can we use the mask on it?

For example, suppose I have a minibatch whose valid sequence lengths are [3,1,2], and with “reduce” we get three loss values [L1, L2, L3]. We need to mask the last two values when calculating L2 and the last value when calculating L3. How can this be achieved?

------------------------------edited----------------------------

Oh I see, we only need to use ignore_index=0 parameter in loss(), but don’t need the “reduce” parameter to achieve this. Is it right?
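
For the [3,1,2] example, a sketch of the other route: compute the loss per token, mask it, and then sum per sequence. It assumes a later PyTorch release, where the unreduced loss is requested with reduction="none" (the successor of reduce=False); the data is made up.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

lengths = torch.tensor([3, 1, 2])           # valid sequence lengths
batch_size, max_len, vocab_size = 3, 3, 7
logits = torch.randn(batch_size, max_len, vocab_size)
targets = torch.randint(1, vocab_size, (batch_size, max_len))

# One loss value per token, reshaped back to (batch_size, max_len).
per_token = F.cross_entropy(
    logits.view(-1, vocab_size), targets.view(-1),
    reduction="none").view(batch_size, max_len)

# mask[b, t] is 1 when step t is inside sequence b.
mask = (torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)).float()

per_sequence = (per_token * mask).sum(dim=1)   # [L1, L2, L3]
loss = per_sequence.sum() / mask.sum()         # mean over valid tokens
print(per_sequence, loss)
```

With ignore_index alone the masking happens inside the loss, which is enough when only the averaged scalar is needed; the per-token form is useful when per-sequence values such as L1, L2, L3 are wanted.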