Are RNN sorting operations autograd-safe?

Hi all,

I’m coding an application where I want to implement something like this:

  1. A single shared RNN:
    Sentence1 --> RNN --> output1
    Sentence2 --> RNN --> output2
  2. Concatenation of both outputs --> classifier --> single final output

Both sentences have different lengths. When using variable-length inputs, LSTMs in PyTorch require the batch to be sorted by sequence length. Since sentence1 and sentence2 have different lengths, I need to sort the batch by the lengths of sentence1, then separately by the lengths of sentence2, feed each sorted batch to the RNN, and afterwards restore the initial batch ordering so that the classifier receives the correct combination of sentence1 and sentence2.

My concern is whether these reordering operations can affect the computation of the autograd gradients, or whether they are safe.

Below is the code of the forward function that I'm using:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def sort_batch(data, seq_len):
    """Sort the data (B, T, D) and the sequence lengths by descending length."""
    sorted_seq_len, sorted_idx = seq_len.sort(0, descending=True)
    sorted_data = data[sorted_idx]
    return sorted_data, sorted_seq_len, sorted_idx


class QuestCompNet(nn.Module):
    def __init__(self, feature_sz, hidden_sz, num_layers):
        super(QuestCompNet, self).__init__()
        self.rnn = nn.LSTM(input_size=feature_sz,
                           hidden_size=hidden_sz,
                           num_layers=num_layers)
        self.l1 = nn.Linear(2 * hidden_sz, hidden_sz)
        self.l2 = nn.Linear(hidden_sz, 1)
        self.hidden_size = hidden_sz
        self.num_layers = num_layers

    def forward(self, x1, x2, seq_len1, seq_len2):
        batch_size = x1.size(0)
        assert x1.size(0) == x2.size(0)
        # init states of the LSTMs
        h1, c1 = self.init_LSTM(batch_size)
        h2, c2 = self.init_LSTM(batch_size)

        # sort each batch by descending sequence length
        x1_s, seq_len1_s, initial_idx1 = sort_batch(x1, seq_len1)
        x2_s, seq_len2_s, initial_idx2 = sort_batch(x2, seq_len2)

        # pack the batches
        x1_s = pack_padded_sequence(x1_s, list(seq_len1_s), batch_first=True)
        x2_s = pack_padded_sequence(x2_s, list(seq_len2_s), batch_first=True)

        # forward pass of the two questions through the shared RNN
        out1, h1 = self.rnn(x1_s, (h1, c1))
        out2, h2 = self.rnn(x2_s, (h2, c2))

        # unpack the outputs
        out1, _ = pad_packed_sequence(out1, batch_first=True)
        out2, _ = pad_packed_sequence(out2, batch_first=True)

        # index of the last output for each sequence
        idx1 = (seq_len1_s - 1).view(-1, 1).expand(out1.size(0), out1.size(2)).unsqueeze(1)
        idx2 = (seq_len2_s - 1).view(-1, 1).expand(out2.size(0), out2.size(2)).unsqueeze(1)

        # last output of every sequence
        last1 = out1.gather(1, Variable(idx1)).squeeze()
        last2 = out2.gather(1, Variable(idx2)).squeeze()

        # restore initial ordering
        last1 = last1[initial_idx1]
        last2 = last2[initial_idx2]

        # prepare input for the classification layers
        class_layer = torch.cat((last1, last2), 1)
        class_layer2 = F.relu(self.l1(class_layer))

        return F.sigmoid(self.l2(class_layer2)).squeeze()

    def init_LSTM(self, batch_size=16):
        # take a pointer to the parameters to get their type afterwards
        weight = next(self.parameters()).data
        # weight.new constructs a new tensor of the same data type with
        # the indicated dimensions
        hidden_state = Variable(weight.new(self.num_layers, batch_size, self.hidden_size).zero_()).cuda()
        cell_state = Variable(weight.new(self.num_layers, batch_size, self.hidden_size).zero_()).cuda()

        return hidden_state, cell_state

I call the forward function in this way:

    X1_batch, X2_batch, seq_len1, seq_len2, y_batch, _ = ds.sample(batch_size)                                                                             
    X1_batch = Variable(X1_batch).cuda()                                    
    X2_batch = Variable(X2_batch).cuda()                                    
    y_batch = Variable(y_batch).cuda()                                      
    seq_len1 = seq_len1.cuda()                                              
    seq_len2 = seq_len2.cuda()  
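
Just to make the context concrete, a minimal sketch of how this forward call might be wrapped in a training step could look like the following (the BCE loss, the Adam optimizer and the constructor arguments here are only illustrative assumptions, not necessarily what I actually use):

    # hypothetical training step; loss, optimizer and sizes are placeholders
    model = QuestCompNet(feature_sz, hidden_sz, num_layers).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.BCELoss()  # forward() already ends with a sigmoid

    optimizer.zero_grad()
    y_pred = model(X1_batch, X2_batch, seq_len1, seq_len2)
    loss = criterion(y_pred, y_batch.float())
    loss.backward()  # gradients flow back through the sort/pack/gather/index ops
    optimizer.step()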

Thanks in advance!

Jordi

I’m having basically the same problem here. Did you figure out what to do?

In my model I ran into the same problem, and I tried to restore the original order after the sequences were processed by the RNN.
I am not sure if it is right; please point out any problems it has.

def forward(self, input, lengths, hidden):
    # Sort the input and lengths as the descending order
    lengths, perm_index = lengths.sort(0, descending=True)
    input = input[perm_index]

    # pack/unpack presumably alias pack_padded_sequence / pad_packed_sequence
    packed_input = pack(input, list(lengths.data), batch_first=True)
    output, hidden = self.rnn(packed_input, hidden)
    output = unpack(output, batch_first=True)[0]

    # restore the sorting
    odx = perm_index.view(-1, 1).unsqueeze(1).expand(output.size(0), output.size(1), output.size(2))
    decoded = output.gather(0, odx)
    return decoded, hidden

Hey!

Please let me know if anyone has figured out this issue.

Thanks,
Shashank

The above code doesn’t restore initial ordering. Use torch.Tensor.scatter_ instead.
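
For example, a toy snippet like this (not taken from the code above, just to illustrate the idea) restores the original order either with scatter_ or with the inverse permutation:

    import torch

    lengths = torch.LongTensor([3, 5, 2, 4])
    sorted_lengths, perm_index = lengths.sort(0, descending=True)

    # stand-in for the RNN output in sorted order
    output_sorted = torch.arange(4).float().view(4, 1)

    # option 1: scatter_ the sorted rows back into their original positions
    restored = torch.zeros_like(output_sorted)
    restored.scatter_(0, perm_index.view(-1, 1), output_sorted)

    # option 2: index with the inverse permutation
    _, inverse_index = perm_index.sort(0)
    assert torch.equal(restored, output_sorted[inverse_index])

Inside forward() the inverse-permutation indexing is probably the simpler option, since it stays inside the autograd graph without any in-place writes.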

I guess using gather makes the same mistake as @Jordi_de_la_Torre's code.
I think the correct way is to use .scatter_ as suggested by @abhishek0318, or to keep using the indexing method as follows:

    sorted_lengths, sorted_id = lengths.sort(0, descending=True)
    print(sorted_lengths)

    # sorting the permutation indices gives back the inverse permutation
    sorted_sorted_id, initial_id = sorted_id.sort(0, descending=False)

    sorted_input = input[sorted_id]

    print(sorted_input)
    print(sorted_input[initial_id])  # matches the original input order

I wrote a toy program to check the correctness of this usage.
As for autograd safety, I think permuting the tensor indices is fine and won't cause any gradient problems.
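
The kind of toy check I mean is roughly this (a minimal sketch using the plain tensor API; the factor of 2 stands in for an arbitrary differentiable computation):

    import torch

    x = torch.randn(4, 3, requires_grad=True)
    lengths = torch.LongTensor([3, 5, 2, 4])

    # sort by length, apply some computation, then restore the original order
    _, sorted_id = lengths.sort(0, descending=True)
    _, initial_id = sorted_id.sort(0)
    out_permuted = (x[sorted_id] * 2.0)[initial_id]

    # the same computation without any permutation
    out_plain = x * 2.0

    grad_permuted = torch.autograd.grad(out_permuted.sum(), x)[0]
    grad_plain = torch.autograd.grad(out_plain.sum(), x)[0]

    # the permutation and its inverse are transparent to autograd
    assert torch.equal(grad_permuted, grad_plain)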