[Solved] Multiple PackedSequence input ordering

The second output of `sort` is the indices into the original ordering.

You can use these indices with a `scatter_` operation to restore the original order.

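```python
import torch

x = torch.randn(10)
y, ind = torch.sort(x, 0)  # y[i] == x[ind[i]]
# Write each sorted value back to its original position: unsorted[ind[i]] = y[i]
unsorted = torch.empty_like(y)
unsorted.scatter_(0, ind, y)
print((x - unsorted).abs().max())  # tensor(0.)
```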
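For a 1D tensor the same inverse permutation can also be written with plain indexing. A small sketch, not from the original post: either assign through `ind` (equivalent to the `scatter_` call), or gather with `ind.argsort()`, since the argsort of a permutation is its inverse.

```python
import torch

x = torch.randn(10)
y, ind = torch.sort(x, 0)

# Option 1: index assignment, equivalent to unsorted.scatter_(0, ind, y)
unsorted_assign = torch.empty_like(y)
unsorted_assign[ind] = y

# Option 2: gather with the inverse permutation of ind
inverse = ind.argsort()
unsorted_gather = y[inverse]

print(torch.equal(unsorted_assign, x), torch.equal(unsorted_gather, x))
```

The gather form is convenient when the tensor has more dimensions, because `y[inverse]` permutes whole rows along dim 0 without building an expanded index.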

Thank you for the sample code. It helps me unsort using the indices. However, how could you do something similar for the hidden states, to unsort them back into the initial input order? I have some trouble because it’s a 3D tensor and ind is only 1D.

Thank you!

EDIT: A way without scatter is described in the thread "How to properly unsort unpacked sequences?", but I’m not sure the gradient propagates correctly, and it doesn’t seem very efficient. Can someone confirm this?
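On the gradient question: indexing with an index tensor is differentiable, so autograd routes each gradient back to the row it came from. A minimal check, assuming a batch-first output `h` of shape (batch, seq_len, hidden) and a final state `h_n` of shape (num_layers * num_directions, batch, hidden) as returned by `nn.LSTM`; the tensors and the `sorted_idx` value here are hypothetical stand-ins.

```python
import torch

batch, seq_len, hidden = 4, 5, 3
sorted_idx = torch.tensor([2, 0, 3, 1])  # hypothetical result of lengths.sort(descending=True)
inverse = sorted_idx.argsort()           # inverse permutation

# Random stand-ins for LSTM outputs computed on the sorted batch
h_sorted = torch.randn(batch, seq_len, hidden, requires_grad=True)
h_n_sorted = torch.randn(1, batch, hidden, requires_grad=True)

h = h_sorted[inverse]         # batch is dim 0 for batch_first outputs
h_n = h_n_sorted[:, inverse]  # batch is dim 1 for h_n / c_n, even with batch_first=True

# Weight each original sample differently so a wrong gradient routing would show up
weights = torch.arange(1.0, batch + 1)
(h.sum(dim=(1, 2)) * weights).sum().backward()

# Row i of h_sorted came from original position sorted_idx[i], so it gets weights[sorted_idx[i]]
print(h_sorted.grad[:, 0, 0])  # tensor([3., 1., 4., 2.])
```

Note that `h_n` needs the permutation on dim 1, not dim 0, which is an easy place to get the unsorting wrong.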

Any idea if it’s a feature already?
Thanks in advance!

Can you post a snippet of your code?

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Inside the model's forward(); pad_lists is a custom helper that pads the index lists.
lengths = torch.tensor([len(indices) for indices in indices_list], dtype=torch.long, device=device)
lengths_sorted, sorted_idx = lengths.sort(descending=True)

indices_padded = pad_lists(indices_list, padding_idx, dtype=torch.long, device=device)  # custom function
indices_sorted = indices_padded[sorted_idx]

embeddings_padded = self.embedding(indices_sorted)
embeddings_packed = pack_padded_sequence(embeddings_padded, lengths_sorted.tolist(), batch_first=True)

h, (h_n, _) = self.lstm(embeddings_packed)

h, _ = pad_packed_sequence(h, batch_first=True, padding_value=padding_idx)

# Reverses sorting: row i of the sorted batch goes back to row sorted_idx[i].
# scatter_ needs an index with the same shape as h, hence the unsqueeze/expand.
h = torch.zeros_like(h).scatter_(0, sorted_idx.unsqueeze(1).unsqueeze(1).expand(-1, h.shape[1], h.shape[2]), h)
```

This should help.
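A self-contained sketch of the same pipeline, under some assumptions: `torch.nn.utils.rnn.pad_sequence` stands in for the custom `pad_lists`, and the vocabulary size, layer sizes and toy index lists are made up. It applies the scatter-based unsort to `h`, unsorts `h_n` along its batch dimension (dim 1), and checks both against running each sequence through the LSTM on its own.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

torch.manual_seed(0)
padding_idx = 0
embedding = nn.Embedding(10, 4, padding_idx=padding_idx)  # toy vocabulary of 10 tokens
lstm = nn.LSTM(4, 3, batch_first=True)

# Hypothetical toy batch of index lists with different lengths
indices_list = [[1, 2], [3, 4, 5, 6], [7], [8, 9, 1]]
lengths = torch.tensor([len(indices) for indices in indices_list], dtype=torch.long)
lengths_sorted, sorted_idx = lengths.sort(descending=True)

# pad_sequence stands in for the custom pad_lists helper from the post
indices_padded = pad_sequence([torch.tensor(seq) for seq in indices_list],
                              batch_first=True, padding_value=padding_idx)
indices_sorted = indices_padded[sorted_idx]

packed = pack_padded_sequence(embedding(indices_sorted), lengths_sorted.tolist(), batch_first=True)
h, (h_n, _) = lstm(packed)
h, _ = pad_packed_sequence(h, batch_first=True)  # padded positions are filled with 0

# Unsort the outputs (batch is dim 0) with the scatter from the post
index = sorted_idx.view(-1, 1, 1).expand(-1, h.shape[1], h.shape[2])
h = torch.zeros_like(h).scatter_(0, index, h)
# Unsort the final hidden state (batch is dim 1) by indexing with the inverse permutation
h_n = h_n[:, sorted_idx.argsort()]

# Reference: run each sequence on its own, without padding or packing
matches = []
with torch.no_grad():
    for b, seq in enumerate(indices_list):
        out, (ref_h_n, _) = lstm(embedding(torch.tensor(seq)).unsqueeze(0))
        matches.append(torch.allclose(h[b, :len(seq)], out[0], atol=1e-5)
                       and torch.allclose(h_n[:, b], ref_h_n[:, 0], atol=1e-5))
print(matches)
```

The scatter and the `h[sorted_idx.argsort()]` indexing produce the same result; the indexing form avoids building the expanded index and is easy to reuse for `h_n` and `c_n` on dim 1.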


Thanks! I’ve been trying to figure out how to reverse the sorting, and this is the best solution so far.

Wouldn’t it be a good idea to just sort the labels (classes) with the same indices after sorting the data, instead of reversing the order?

Actually, they are equivalent.
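To illustrate the equivalence: a loss averaged over the batch does not depend on sample order, so permuting the labels with `sorted_idx` gives the same value as unsorting the predictions. A sketch with random stand-in logits, labels and permutation (all hypothetical):

```python
import torch
import torch.nn.functional as F

batch, num_classes = 5, 3
logits_sorted = torch.randn(batch, num_classes)   # stand-in for model output in sorted order
labels = torch.randint(0, num_classes, (batch,))  # labels in the original order
sorted_idx = torch.randperm(batch)                # stand-in for lengths.sort(descending=True)[1]

# Option A: sort the labels to match the sorted batch
loss_a = F.cross_entropy(logits_sorted, labels[sorted_idx])
# Option B: unsort the predictions back to the original order
loss_b = F.cross_entropy(logits_sorted[sorted_idx.argsort()], labels)
print(torch.allclose(loss_a, loss_b))
```

Both options pair each prediction with the same label. Sorting the labels is enough for training; unsorting is still needed if the outputs must be returned or stored in the original order.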

Do we still need to sort the batch by decreasing sequence length before pack_padded_sequence, or has this been improved recently?

@Diego999 I think it has been improved: since PyTorch 1.1 you don’t need to sort the sequences yourself; pass `enforce_sorted=False` to `pack_padded_sequence`.
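With `enforce_sorted=False`, `pack_padded_sequence` sorts internally and stores the permutation in the `PackedSequence`, so `pad_packed_sequence` and the LSTM's `h_n` come back in the original batch order. A sketch with toy sizes (assumed), comparing it with the manual sort-then-unsort approach from this thread:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
lstm = nn.LSTM(4, 3, batch_first=True)
x = torch.randn(3, 5, 4)           # padded batch in original (unsorted) order
lengths = torch.tensor([2, 5, 3])  # deliberately not sorted

# No manual sorting: PyTorch tracks the permutation inside the PackedSequence
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
out, (h_n, _) = lstm(packed)
out, out_lengths = pad_packed_sequence(out, batch_first=True)

# Manual approach for comparison: sort, pack, run, then unsort
lengths_sorted, sorted_idx = lengths.sort(descending=True)
inverse = sorted_idx.argsort()
packed_sorted = pack_padded_sequence(x[sorted_idx], lengths_sorted, batch_first=True)
out_m, (h_n_m, _) = lstm(packed_sorted)
out_m, _ = pad_packed_sequence(out_m, batch_first=True)
out_m, h_n_m = out_m[inverse], h_n_m[:, inverse]

print(torch.allclose(out, out_m), torch.allclose(h_n, h_n_m), out_lengths.tolist())
```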