How to properly unsort unpacked sequences?

UPDATE: OK, I think I have found a solution. Would doing it as shown in the minimal example code below be correct?

I am trying to use packed sequences for my model, but since I have more than one input with sequences, I cannot simply sort the batch once: each input may require the batch to be sorted in a different order.

So the only solution I can see here is to “unsort” the unpacked sequences once I have sent them through the LSTM.

The LSTM returns (out,(hend,cend)) where out is a packed sequence. When I unpack this, I get my sequences sorted by length.

I still have my indices from sorting, but all the methods I tried for using them to “unsort” either break my gradient chain or make PyTorch complain.
If I use the method based on newvar=sortedvar.gather(dim,unsort_idxs) then I get the error message “save_for_backward can only save input or output tensors, but argument 0 doesn’t satisfy this condition”.
UPDATE: I think I have found a much simpler solution now (see below). Would this be the correct way to do it?
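To check the indexing idea on its own, without the LSTM, here is a small sketch. It assumes a recent PyTorch version where plain tensors track gradients, so no Variable wrapper is needed. The tensor `x` and the names `restored_index`, `restored_select` and `restored_gather` are made up for illustration. The point is that sorting the sort indices gives the inverse permutation, and that plain indexing, `index_select` and `gather` with an expanded index should all restore the original order while keeping gradients flowing:

```python
import torch

# hypothetical sequence lengths, same as in the example in this post
lens = torch.tensor([2, 1, 6, 4])
lens_sorted, idx = lens.sort(0, descending=True)

# sorting the permutation itself yields the inverse permutation
_, orig_idx = idx.sort(0)

# stand-in for an unpacked LSTM output of shape (batch, maxseq, hidden)
x = torch.randn(4, 6, 5, requires_grad=True)
x_sorted = x[idx]

# three ways to undo the sort along the batch dimension
restored_index = x_sorted[orig_idx]
restored_select = x_sorted.index_select(0, orig_idx)
gather_idx = orig_idx.view(-1, 1, 1).expand_as(x_sorted)
restored_gather = x_sorted.gather(0, gather_idx)

# gradients should reach x through the sort/unsort round trip
restored_index.sum().backward()
```

If the three results agree with `x` and `x.grad` is populated, the unsorting step itself is not what breaks the graph.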

Here is some minimal code to illustrate what I am trying to do:

import torch
from torch.autograd import Variable as V
# make it easier to create the sequences of vectors
embs1=torch.nn.Embedding(100,3,padding_idx=0)
lstm1 = torch.nn.LSTM(3,5,1)  # LSTM takes 3-dimensional inputs
batch=V(torch.LongTensor([[1,2,0,0,0,0],[3,0,0,0,0,0],[2,4,5,2,3,1],[4,1,2,2,0,0]]))
e_batch = embs1(batch)   # get the batch of sequences of embeddings
# Note: the batch is of shape batchsize,maxseq,3 so we need to use batch_first later
# these are my sequence lengths
lens = [2,1,6,4]
lens_sorted,idx = torch.IntTensor(lens).sort(0, descending=True)
# sort the embeddings batch by lengths
e_batch_sorted = e_batch[idx]
# create the packed sequences
packed=torch.nn.utils.rnn.pack_padded_sequence(e_batch_sorted, lens_sorted.tolist(), batch_first=True)
# get the output from the lstm
(out,(hout,cout))=lstm1(packed)
(unpacked_out,_) = torch.nn.utils.rnn.pad_packed_sequence(out,batch_first=True)
# would like to unsort like so:
_,orig_idx = idx.sort(0)
# original attempt, which did not work
# unsort_idx = orig_idx.view(-1,1,1).expand_as(unpacked_out)
# unsorted = unpacked_out.gather(0,unsort_idx.long())

# UPDATE: this seems to do what I want:
unsorted = unpacked_out[orig_idx]

When I try to run this with the original gather-based attempt, the very last statement gives me the error: “save_for_backward can only save input or output tensors, but argument 0 doesn’t satisfy this condition”
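For comparison, here is a self-contained sketch of the whole round trip. It assumes a newer PyTorch release (1.1 or later, as far as I know) where `pack_padded_sequence` accepts `enforce_sorted=False` and tensors no longer need the Variable wrapper. Names such as `h_unsorted`, `packed_auto` and `unpacked_auto` are only illustrative. It unsorts manually with the inverse permutation and then checks the result against letting PyTorch handle the ordering internally:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
embs = torch.nn.Embedding(100, 3, padding_idx=0)
lstm = torch.nn.LSTM(3, 5, 1)

batch = torch.tensor([[1, 2, 0, 0, 0, 0],
                      [3, 0, 0, 0, 0, 0],
                      [2, 4, 5, 2, 3, 1],
                      [4, 1, 2, 2, 0, 0]])
lens = torch.tensor([2, 1, 6, 4])
e_batch = embs(batch)  # shape (batch, maxseq, 3)

# manual route: sort by length, pack, run, unpack, unsort
lens_sorted, idx = lens.sort(0, descending=True)
_, orig_idx = idx.sort(0)  # inverse permutation
packed = pack_padded_sequence(e_batch[idx], lens_sorted.tolist(), batch_first=True)
out, (h_n, c_n) = lstm(packed)
unpacked, _ = pad_packed_sequence(out, batch_first=True)
unsorted = unpacked[orig_idx]   # back in original batch order
h_unsorted = h_n[:, orig_idx]   # final states have batch on dim 1

# alternative route: let PyTorch track the permutation itself
packed_auto = pack_padded_sequence(e_batch, lens.tolist(), batch_first=True,
                                   enforce_sorted=False)
out_auto, (h_auto, c_auto) = lstm(packed_auto)
unpacked_auto, lens_auto = pad_packed_sequence(out_auto, batch_first=True)

# gradients should flow back to the embedding weights
unsorted.sum().backward()
```

Note that the final hidden and cell states come back in sorted order too on the manual route, so they need the same unsorting, just along dimension 1 instead of 0.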
