[Solved] Multiple PackedSequence input ordering

Hi !

I’m new to PyTorch (moving from Torch), and I’m having some trouble implementing a model …

I have two variable-length time-series sequences that will be forwarded through the same network, and their outputs will be compared using the cosine distance.

The problem here is that PackedSequence expects the sequences to be ordered by sequence length. However, my model takes two sequences as input, so if I sort each of them, the first sequence of the first input may no longer be paired with the first sequence of the second input.

What can I use to compare the correct pairs of sequences?

7 Likes

This comes up fairly frequently (e.g. in machine translation). The current solution is to use torch.sort to order examples by length on each side and to store the indices so that the order can be reversed later. But it’s easy to get this wrong, and it should probably be added as a convenience function to the RNN utils. I may do that if I have time, or if anyone else wants to make a PR, they should go ahead.
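For reference, a minimal sketch of that torch.sort pattern (the tensors are illustrative placeholders, not code from this thread; written against a recent PyTorch API):

import torch

lengths = torch.tensor([3, 7, 5])                 # per-example lengths (illustrative)
batch = torch.randn(3, 7, 4)                      # padded batch (B, T, F)

sorted_lengths, sort_idx = lengths.sort(descending=True)
_, inverse_idx = sort_idx.sort()                  # permutation that undoes the sort

sorted_batch = batch[sort_idx]                    # the order pack_padded_sequence expects
# ... pack sorted_batch, run the RNN, unpack to `output` ...
# output = output[inverse_idx]                    # restore the original example order

assert torch.equal(sorted_batch[inverse_idx], batch)   # the sort round-trips correctly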

7 Likes

My forward function looks like:

def forward(self, dict_index, features, prev_hidden, seq_sizes, original_index):
    i2e = self.embedding(dict_index)
    data = torch.cat((i2e, features), 2)

    packed = pack_padded_sequence(data, list(seq_sizes.data), batch_first=True)
    output, _ = self.rnn(packed, prev_hidden)

    output, _ = pad_packed_sequence(output, batch_first=True)
    # get the last time step for each sequence
    idx = (seq_sizes - 1).view(-1, 1).expand(output.size(0), output.size(2)).unsqueeze(1)
    decoded = output.gather(1, idx).squeeze()

    # restore the sorting
    decoded[original_index] = decoded

    return decoded

and then I compare the decoded outputs using the cosine embedding loss. Will PyTorch compute the correct gradients on backward? Do I need to modify something? Since I’m changing the order of the data in the last step of the forward pass, do I need to “reorder” the loss values?

Hi,

I did this a few days ago:

import torch

# Helper functions

def argsort(seq):
    # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
    return sorted(range(len(seq)), key=seq.__getitem__)

# apply padding
def pad(tensor, length):
    return torch.cat([tensor, tensor.new(length - tensor.size(0), *tensor.size()[1:]).zero_()])

lengths = []
for i in range(batch_size):
    length_i = input_variables[i].size(0)
    lengths.append(length_i)

# sort in numpy - TODO do this in pure python or PyTorch
#lengths_np = np.array((lengths))
#indx = np.argsort(lengths_np, axis=0)[::-1]
#inverse_indx = np.argsort(indx, axis=0)

#sort - #see http://stackoverflow.com/questions/2483696/undo-or-reverse-argsort-python
indx = argsort( lengths ) # ascending
indx = indx[::-1]         # descending
inverse_indx = argsort(indx)

# check
#lengths_dec = [lengths[i] for i in indx] # sorted lengths
#remap_lengths = [lengths_dec[i] for i in inverse_indx] # map 
#lengths == remap_lengths

padded_vec_lst = []
for i in range(batch_size):
    padded_vec = pad(vec_lst[i], max_length)
    padded_vec_lst.append(padded_vec)

It works :slight_smile:

HTHs

5 Likes

Hi ! Thx for your code :wink:
Currently I’m ordering them using PyTorch’s torch.sort function.

But I’m not sure whether the gradients are correct, because my grad_output follows the original data ordering (I need to maintain the original order to compare the distance between the two sequences), while each model receives its sequences in a different ordering.

So I assume that I’ll need to reorder the grad_output (loss derivatives) to match the forward order.

But I’m not sure if I really need to do it, or if pytorch does it internally. And if necessary, I don’t know how to reorder the grad_output :frowning:

To update: I solved this problem.
I checked, and yes, PyTorch’s backward handles the gradient reordering. I just replaced the

decoded[original_index] = decoded

line with gather.

Working now, thx everyone :wink:
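For reference, a minimal, self-contained sketch of that gather-based reordering (illustrative shapes; here original_index is assumed to be the inverse of the sorting permutation, i.e. it maps each output row back to its original position):

import torch

decoded = torch.randn(4, 8)                      # (batch, hidden), in sorted order
original_index = torch.tensor([2, 0, 3, 1])      # inverse of the sorting permutation

odx = original_index.view(-1, 1).expand_as(decoded)
restored = decoded.gather(0, odx)

# equivalent and a bit more direct:
assert torch.equal(restored, decoded.index_select(0, original_index))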

4 Likes

@aron-bordin could you put a snippet of the final version? I am a bit confused.

@nitish I just updated my forward function above. The final forward is:

def forward(self, dict_index, features, prev_hidden, seq_sizes, original_index):
    i2e = self.embedding(dict_index)
    data = torch.cat((i2e, features), 2)

    packed = pack_padded_sequence(data, list(seq_sizes.data), batch_first=True)
    output, _ = self.rnn(packed, prev_hidden)

    output, _ = pad_packed_sequence(output, batch_first=True)
    # get the last time step for each sequence
    idx = (seq_sizes - 1).view(-1, 1).expand(output.size(0), output.size(2)).unsqueeze(1)
    decoded = output.gather(1, idx).squeeze()

    # restore the sorting
    odx = original_index.view(-1, 1).expand(args.batch_size, output.size(-1))
    decoded = decoded.gather(0, Variable(odx))
    return decoded

6 Likes

@aron-bordin So using this, if you have multiple mini-batches (with different sorting permutations) going into different RNNs, this forward pass returns them in the original order, and then they are directly comparable. PyTorch takes care of this permutation while backpropagating. I hope I understand this correctly.
One question: how did you check that the grads are permuted correctly?

Thanks!

Yes, it’s correct.

I don’t have a sample here, but I tested it by implementing a simple equation: I passed multiple values through this equation, changed the order with gather, applied another function, and then backpropagated. Then I calculated the derivatives manually and compared them with the PyTorch output, so I could verify that PyTorch was taking the reordering into account while backpropagating.
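Along those lines, a small self-contained check (hypothetical values, written against a recent PyTorch API, not the original test):

import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2.0                                   # a trivial "model"

perm = torch.tensor([2, 0, 1])                # some reordering
y_perm = y.index_select(0, perm)              # y_perm[i] = y[perm[i]]

loss = (y_perm * torch.tensor([10.0, 20.0, 30.0])).sum()
loss.backward()

# loss = 10*2*x[2] + 20*2*x[0] + 30*2*x[1], so the gradient is routed back
# through the permutation automatically:
print(x.grad)                                 # tensor([40., 60., 20.])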

1 Like

@aron-bordin Hey, wouldn’t it be easier to change the last-time-step gather in your forward to:

output, (ht, ct) = self.rnn(packed, prev_hidden)
decoded = ht[-1]

Both the decoded above should be the same.

Doesn’t the output of self.rnn already give you the output for t = seq_len, even though we have variable-length input?

Also, may I ask what’s your prev_hidden? I thought it should be the initial state for hidden and cell.

1 Like

Both the decoded above should be the same.

It does look like the two are equivalent. Could anyone please confirm if using ht[-1] is appropriate?

No, they are not the same.
The problem here is that I’m passing a batch of variable-length sequences, so some of the sequences will be padded with zero vectors, leading to unwanted outputs from the RNN.
So, the RNN will produce one output per timestep, and since some of those timesteps are zero (due to the padding), it’s necessary to gather the proper output.

1 Like

My input is also a batch of variable-length sequences. I also padded them and packed them using pack_padded_sequence.

Using ht[-1] returns the same result as your solution.

I think you are expecting ht[-1] to return zeros for the short inputs that were padded, right? But that’s not the case when I tested it. Would you mind double-checking? Or am I missing something here?
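A minimal sketch of that comparison for a single-layer, unidirectional GRU (illustrative shapes; recent PyTorch API):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

batch, max_len, feat, hidden = 3, 5, 4, 6
lengths = [5, 3, 2]                           # already sorted, descending
x = torch.randn(batch, max_len, feat)
for i, l in enumerate(lengths):               # zero out the padded steps
    x[i, l:] = 0

rnn = nn.GRU(feat, hidden, batch_first=True)
packed = pack_padded_sequence(x, lengths, batch_first=True)
out_packed, ht = rnn(packed)
out, _ = pad_packed_sequence(out_packed, batch_first=True)

# last valid timestep per sequence, gathered from the padded output
idx = (torch.tensor(lengths) - 1).view(-1, 1, 1).expand(-1, 1, hidden)
last = out.gather(1, idx).squeeze(1)

print(torch.allclose(last, ht[-1]))           # True for this unidirectional setup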

2 Likes

I’ll check this on Monday and post here. But some points to consider: I’m using both LSTM and GRU RNNs, with a bidirectional architecture. And I remember that I first tested with -1, but this was not working in my case, so I then used the solution above. I’ll confirm it at the beginning of the week :wink:

My prev_hidden is usually zero while training; I just have the parameter so I can evaluate the model from some specific starting points.
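One note on the bidirectional case: for a bidirectional RNN, ht[-1] corresponds to the backward direction of the top layer only, so it is not the same as gathering the last valid timestep from the padded output. A hedged sketch of separating the directions (standard h_n layout, illustrative shapes):

import torch
import torch.nn as nn

num_layers, hidden_size, batch, feat = 2, 6, 3, 4
rnn = nn.GRU(feat, hidden_size, num_layers=num_layers,
             batch_first=True, bidirectional=True)
x = torch.randn(batch, 7, feat)
out, ht = rnn(x)                              # ht: (num_layers * 2, batch, hidden_size)

ht = ht.view(num_layers, 2, batch, hidden_size)
last = torch.cat([ht[-1, 0], ht[-1, 1]], dim=1)   # top layer, both directions
print(last.shape)                             # torch.Size([3, 12])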

Does anyone have sample code for batched multiple packed sequences with Attention?

Here’s where I am confused:

Background: I have a dataset with a series of sentence tuples (s1, s2). I have a bidirectional LSTM for each of them, and I would like to attend over the two sequences s1 and s2 before I send them off to a linear layer. I would also like to do this in a batch setting, though the pseudo-code below is written for a per-instance (i.e. per (s1, s2) tuple) forward pass.

Something like this (not working code, just pseudo-code):

    combined_input_2_attention = torch.cat((s1_lstm_out, s2_lstm_out), 0)
    attention_alphas = self.softmax(self.attention_layer(combined_input_2_attention))
    attn_applied = torch.bmm(attention_alphas.unsqueeze(0),
                             combined_input_2_attention)

    output_embedding = self.hidden_layer(attn_applied)
    output = self.softmax(output_embedding)

where s1_lstm_out and s2_lstm_out are the outputs from sending one (s1, s2) tuple through the forward pass.

Q1. If this were batched, would the attention weights (alphas) need to have the dimension of the max sequence length (s1 + s2, since I am concatenating) per batch, or globally? I doubt this can be done per batch, because how would I initialize the dimensions of the linear layer that computes the attention?

Q2. Either way, I need a packed sequence for each of the two LSTMs before I attend over them or concatenate them, but the problem is that packing requires them to be sorted individually.
I read the answer about using sort, but I could not follow it completely.

Here’s what I understood from the post above:

  1. I take a set of tuples (s1, s2), sort each side individually using tensor.sort, and keep track of the original indices.
  2. Individually pack s1 and s2 into padded sequences based on their individual max lengths per batch and send them through my forward pass.
  3. At the end, I return my outputs by reverse-mapping them based on the sort order generated in (1)? If yes, wouldn’t the network have seen the s1 and s2 instances in orders different from my original pairing? Why is this correct / why does this work? What am I missing?

Any leads /clarification would be extremely helpful! Thanks!

Figured this out:

sequence 1: sort -> pad and pack -> process using RNN -> unpack -> unsort
sequence 2: sort -> pad and pack -> process using RNN -> unpack -> unsort

Do whatever you wanted to do with the unsorted outputs (combine, attend, whatever).
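A hedged sketch of that per-sequence pipeline wrapped as a helper, so it can be applied to each of the two paired inputs independently (names are illustrative, not from this thread):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def run_rnn_unsorted(rnn, padded, lengths):
    # padded: (batch, max_len, feat); lengths: 1-D LongTensor
    # returns the RNN outputs in the ORIGINAL batch order
    sorted_lengths, sort_idx = lengths.sort(descending=True)
    _, inverse_idx = sort_idx.sort()                  # undoes the sort

    packed = pack_padded_sequence(padded[sort_idx], sorted_lengths.tolist(),
                                  batch_first=True)
    out, _ = rnn(packed)
    out, _ = pad_packed_sequence(out, batch_first=True)
    return out[inverse_idx]                           # back to the original order

# s1_out = run_rnn_unsorted(rnn1, s1_padded, s1_lengths)
# s2_out = run_rnn_unsorted(rnn2, s2_padded, s2_lengths)
# s1_out[i] and s2_out[i] now correspond to the same example i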

2 Likes

I think you are better off just keeping it sorted throughout training, as sorting operations are usually faster on the CPU (so it is better to do the sorting at the data-loading stage).
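If you go that route, a minimal sketch is to do the sort inside a DataLoader collate_fn (hypothetical dataset yielding (sequence, label) pairs; pad_sequence needs a reasonably recent PyTorch):

import torch
from torch.nn.utils.rnn import pad_sequence

def sort_and_pad_collate(samples):
    # samples: list of (seq, label), seq of shape (len_i, feat), label an int
    samples.sort(key=lambda s: s[0].size(0), reverse=True)   # longest first
    seqs, labels = zip(*samples)
    lengths = torch.tensor([s.size(0) for s in seqs])
    padded = pad_sequence(list(seqs), batch_first=True)      # (B, max_len, feat)
    return padded, lengths, torch.tensor(labels)

# loader = torch.utils.data.DataLoader(dataset, batch_size=32,
#                                      collate_fn=sort_and_pad_collate)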

2 Likes

@miguelvr Thanks - the reason I am unsorting is that both my sequences are “paired” but pass through different RNNs. If I don’t unsort, I lose the pairing (because packing was done independently).

1 Like

@aron-bordin: I have a couple of questions. To confirm, dict_index is sorted, and original_index is the inverse of the sorting permutation? Is that correct, and how do you get the inverse of the sorting permutation from torch.sort?