Optimization tips for nested RNN


I’m implementing a model that runs two RNN “decoders” over variable-length sequences. In particular, I have an input tensor of shape (batch_size, time_1, time_2, emb_dim). I need to run the second RNN decoder on every slice taken along the second dimension (index 1), i.e. on time_1 tensors of shape (batch_size, time_2, emb_dim). Below is a simplified version of my implementation:

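import torch
import torch.nn as nn

class DoubleDecoder(nn.Module):
    # input_size and hidden_size are placeholder arguments; GRUCell requires both
    def __init__(self, input_size, hidden_size):
        super(DoubleDecoder, self).__init__()
        self.decoder1_cell = nn.GRUCell(input_size, hidden_size)
        self.decoder2 = DecoderModule()  # implementation of a seq2seq decoder (not shown)

    def forward(self, input):
        batch_size, time_1, time_2, emb_dim = input.size()
        acc_outputs_1 = []
        acc_outputs_2 = []
        # initialize the state of decoder1 (init_state is a helper not shown here)
        decoder1_state = init_state()

        for i in range(time_1):
            # slice along dimension 1: (batch_size, time_2, emb_dim)
            curr_input = input[:, i, :, :]
            curr_output_2 = self.decoder2(curr_input)
            # advance decoder1 by one step, conditioned on the decoder2 output
            decoder1_state = self.decoder1_cell(curr_output_2, decoder1_state)
            acc_outputs_1.append(decoder1_state)
            acc_outputs_2.append(curr_output_2)

        return acc_outputs_1, acc_outputs_2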

where decoder1 is an RNN cell (I’m using a GRU, but it doesn’t really matter in this case) and decoder2 is a module that implements a classical attentive seq2seq decoder.
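To make the slicing along dimension 1 concrete, here is a small shape check in isolation. The sizes are arbitrary values chosen only for illustration, not the ones from my model.

```python
import torch

# Arbitrary illustrative sizes (assumptions, not the real model's dimensions).
batch_size, time_1, time_2, emb_dim = 2, 3, 4, 5
x = torch.randn(batch_size, time_1, time_2, emb_dim)

# Indexing position i along dimension 1 yields one (batch_size, time_2, emb_dim) tensor.
slices = [x[:, i, :, :] for i in range(time_1)]

# torch.unbind(x, dim=1) returns the same time_1 slices in a single call.
unbound = torch.unbind(x, dim=1)
```

Each of the time_1 slices is what decoder2 receives in one iteration of the loop.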

I am concerned about the performance of my model. Both decoder1 and decoder2 need to be unrolled for a variable number of steps, so this is essentially two nested for-loops, and I suspect that running them step by step in PyTorch is very expensive. In addition, decoder1 needs to work on slices of the input along a specific dimension. Is there a way to replace the for-loop with something more efficient? I’ve seen other libraries (OpenNMT-py) use an explicit loop for seq2seq models, so I assumed it was the right choice.
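One direction that might remove the outer loop, sketched here under a strong assumption: decoder2 at step i does not depend on decoder1's state (this holds for the simplified code above, but may not hold for the real model). In that case time_1 can be folded into the batch dimension so the inner RNN runs once. The two nn.GRU modules below are stand-ins for illustration only; they are not the attentive decoder or the GRUCell loop from my model, and the sizes are arbitrary.

```python
import torch
import torch.nn as nn

# Arbitrary illustrative sizes (assumptions).
batch_size, time_1, time_2, emb_dim, hidden = 2, 3, 4, 5, 6
x = torch.randn(batch_size, time_1, time_2, emb_dim)

# Stand-ins (assumptions): a plain GRU replaces the attentive decoder2,
# and a second GRU replaces stepping a GRUCell manually for decoder1.
inner = nn.GRU(emb_dim, hidden, batch_first=True)
outer = nn.GRU(hidden, hidden, batch_first=True)

# Fold time_1 into the batch so the inner RNN processes every slice in one call.
folded = x.reshape(batch_size * time_1, time_2, emb_dim)
inner_out, inner_last = inner(folded)  # inner_last: (1, batch_size * time_1, hidden)

# Unfold: one summary vector per (batch, time_1) position.
summaries = inner_last.squeeze(0).reshape(batch_size, time_1, hidden)

# The outer RNN consumes the whole time_1 sequence in a single call.
outer_out, outer_last = outer(summaries)  # outer_out: (batch_size, time_1, hidden)
```

This sketch ignores variable lengths; padded batches would additionally need masking or torch.nn.utils.rnn.pack_padded_sequence. If decoder2 really has to be conditioned on decoder1's state at each step, the outer loop cannot be folded away like this.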

Looking forward to your suggestions!