GRU training very slow with sequence packing

I am developing a two-layer bidirectional GRU model for a sequence classification task, training on Google Colab with an NVIDIA T4 GPU. Since I train in batches and my sequence lengths vary, I pad the sequences to equal length within each batch using pad_sequence. I noticed that my training time increases from ~5 minutes to ~30 minutes if I pack the sequences in the model's forward pass with pack_padded_sequence. Alternatively, I can feed the model the padded sequences without packing, which results in a (small) drop in performance.
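
For context, the packing round-trip on its own behaves roughly like this (a standalone sketch with dummy data, not my actual pipeline):

import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# Three dummy sequences of different lengths, 4 features per time step
seqs = [torch.randn(n, 4) for n in (10, 7, 5)]
lengths = [len(s) for s in seqs]  # [10, 7, 5], already in descending order

padded = pad_sequence(seqs, batch_first=True)               # (3, 10, 4), zero-padded
packed = pack_padded_sequence(padded, lengths, batch_first=True)
# packed.data has shape (sum(lengths), 4) = (22, 4): the padded steps are dropped

unpacked, out_lengths = pad_packed_sequence(packed, batch_first=True)
assert torch.equal(unpacked, padded)                        # round trip restores the padded tensor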

What could be the reason behind this? This is my classifier model:

import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class GRU(nn.Module):
    def __init__(self, n_features, n_classes, hidden_size, num_layers, bidirectional):
        super().__init__()

        self.gru = nn.GRU(input_size=n_features, hidden_size=hidden_size,
                          batch_first=True, bidirectional=bidirectional, num_layers=num_layers)

        self.out = nn.Linear(in_features=hidden_size * 2 if bidirectional else hidden_size,
                             out_features=n_classes)

    def forward(self, x, lengths):
        # Pack the padded batch so the GRU skips the padded time steps
        packed = pack_padded_sequence(x, lengths, batch_first=True).cuda()
        x, hidden = self.gru(packed)
        # Unpack back to a padded tensor of shape (batch_size, max_len, num_directions * hidden_size)
        x, output_lengths = pad_packed_sequence(x, batch_first=True)
        x = self.out(x)
        # Alternative, faster way (no packing, the GRU also runs over the padding):
        # x, hidden = self.gru(x)
        # x = self.out(x)
        return x

    def reset(self):
        for layer in self.children():
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()
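
For completeness, this is roughly how the model is instantiated and called (the hyperparameter values here are just placeholders, not my actual settings):

# hidden_size, n_features and n_classes below are placeholder values
model = GRU(n_features=10, n_classes=5, hidden_size=128, num_layers=2, bidirectional=True).to(device)
output = model(features, lengths)   # features: (batch_size, max_len, n_features), lengths: list of ints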

This is my collate_fn:

from torch.nn.utils.rnn import pad_sequence

def collate(batch):
  """
    Makes a mini-batch of sleep recordings by padding them to equal length.

    Args:
      batch: A list of SleepRecording instances

    Returns:
      features_padded: A padded feature tensor of shape (batch_size, rec_len, n_features)
      labels_padded: A padded label tensor of shape (batch_size, rec_len)
      lengths: A list of the original lengths of the sleep recordings within the batch

      Where rec_len is the length of the longest sleep recording within the batch.
  """
  # Get the length of each sleep recording in the batch
  lengths = [len(recording.features) for recording in batch]

  # Sort the recordings based on sequence length (in descending order)
  sorted_indices = sorted(range(len(lengths)), key=lengths.__getitem__, reverse=True)
  batch = [batch[i] for i in sorted_indices]
  
  features = [torch.tensor(recording.features) for recording in batch]

  # Shift classes from scale [1, C] to [0, C-1] as required by CrossEntropyLoss
  labels = [torch.tensor(recording.labels - 1) for recording in batch]

  # Make the recordings in the batch equal length by padding shorter recordings
  # to the length of the longest sleep recording.
  features_padded = pad_sequence(features, batch_first=True, padding_value=PADDING_VALUE)
  labels_padded = pad_sequence(labels, batch_first=True, padding_value=PADDING_VALUE)

  return features_padded, labels_padded, sorted(lengths, reverse=True)

Part of my training loop:

from torch.utils.data import DataLoader

trainloader = DataLoader(train_dataset, collate_fn=collate, batch_size=batch_size, shuffle=True)

for epoch in range(max_epochs):
    model.train()
    for features, labels, lengths in trainloader:
        features = features.to(device).type(torch.float)
        labels = labels.to(device).type(torch.long)

        optimizer.zero_grad()

        output = model(features, lengths).reshape(-1, NUM_CLASSES)
        loss = criterion(output, labels.flatten())

        loss.backward()
        optimizer.step()

I also ran the PyTorch profiler and got this output for the approach where I pack the sequences:
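
The profiling is set up roughly like this (a simplified sketch; the record_function labels correspond to the model_forward and compute_loss rows in the tables):

from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with record_function("model_forward"):
        output = model(features, lengths).reshape(-1, NUM_CLASSES)
    with record_function("compute_loss"):
        loss = criterion(output, labels.flatten())

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))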

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          model_forward         0.66%     526.000us        99.64%      79.596ms      79.596ms       0.000us         0.00%      51.900ms      51.900ms             1  
                                              aten::gru         0.11%      90.000us        94.75%      75.685ms      75.685ms       0.000us         0.00%      51.827ms      51.827ms             1  
                                       aten::_cudnn_rnn        42.56%      34.000ms        94.61%      75.573ms      75.573ms      51.827ms        99.55%      51.827ms      51.827ms             1  
void elemWiseRNNcell<float, float, float, (cudnnRNNM...         0.00%       0.000us         0.00%       0.000us       0.000us      23.953ms        46.01%      23.953ms       5.070us          4724  
void gemmSN_TN_kernel<float, 128, 16, 2, 4, 4, 4, fa...         0.00%       0.000us         0.00%       0.000us       0.000us      12.744ms        24.48%      12.744ms       6.000us          2124  
void gemmSN_TN_kernel<float, 128, 16, 2, 4, 4, 4, tr...         0.00%       0.000us         0.00%       0.000us       0.000us      10.152ms        19.50%      10.152ms       6.000us          1692  
std::enable_if<!(false), void>::type internal::gemvx...         0.00%       0.000us         0.00%       0.000us       0.000us       2.935ms         5.64%       2.935ms       5.026us           584  
void gemmSN_TN_kernel<float, 128, 16, 2, 4, 2, 2, tr...         0.00%       0.000us         0.00%       0.000us       0.000us       1.944ms         3.73%       1.944ms       6.000us           324  
                                           compute_loss         0.18%     141.000us         0.33%     267.000us     267.000us       0.000us         0.00%     159.000us     159.000us             1  
                               aten::cross_entropy_loss         0.01%      10.000us         0.15%     117.000us     117.000us       0.000us         0.00%     159.000us     159.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------

and this output when I am not packing the sequences:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          model_forward         3.48%     196.000us        57.85%       3.261ms       3.261ms       0.000us         0.00%       2.975ms       2.975ms             1  
                                              aten::gru         1.06%      60.000us        28.95%       1.632ms       1.632ms       0.000us         0.00%       2.953ms       2.953ms             1  
                                       aten::_cudnn_rnn        24.57%       1.385ms        27.53%       1.552ms       1.552ms       2.949ms        96.22%       2.953ms       2.953ms             1  
void RNN_blockPersist_fp_GRU<float, float, float, 32...         0.00%       0.000us         0.00%       0.000us       0.000us       2.898ms        94.55%       2.898ms     724.500us             4  
                                           compute_loss         1.21%      68.000us         2.64%     149.000us     149.000us       0.000us         0.00%      90.000us      90.000us             1  
                               aten::cross_entropy_loss         0.09%       5.000us         1.33%      75.000us      75.000us       0.000us         0.00%      90.000us      90.000us             1  
                                      aten::nll_loss_nd         0.04%       2.000us         0.55%      31.000us      31.000us       0.000us         0.00%      87.000us      87.000us             1  
                                         aten::nll_loss         0.04%       2.000us         0.51%      29.000us      29.000us       0.000us         0.00%      87.000us      87.000us             1  
                                 aten::nll_loss_forward         0.34%      19.000us         0.48%      27.000us      27.000us      87.000us         2.84%      87.000us      87.000us             1  
void at::native::(anonymous namespace)::nll_loss_for...         0.00%       0.000us         0.00%       0.000us       0.000us      87.000us         2.84%      87.000us      87.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------