How to Make Stateful RNN Training with Padded Sequences More Efficient?

I am trying to train an LSTM network in “stateful” mode, i.e. initializing a mini-batch’s hidden/cell state with the previous mini-batch’s last hidden/cell state. However, each of my mini-batches has sequences with padding at the end. This is my problem: I want to avoid initializing the next mini-batch’s hidden/cell state with hidden/cell states taken from padded time-steps.

I am able to achieve this by packing the padded sequences, like so:

import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class LayerNorm_LSTM_Stateful(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.layernorm = nn.LayerNorm(input_dim)
        self.lstm = nn.LSTM(input_dim, output_dim, batch_first=True, bidirectional=False)

    def forward(self, x, lengths=None, hidden=None):
        x = self.layernorm(x)

        if hidden is None or lengths is None:
            x, hidden = self.lstm(x, hidden)
        else:
            # Packing makes the LSTM skip padded steps, so the returned
            # hidden/cell states come from each sequence's last valid
            # time-step (lengths must be a CPU tensor or a list of ints).
            x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
            x, hidden = self.lstm(x, hidden)
            x, _ = pad_packed_sequence(x, batch_first=True)

        return x, hidden
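
For context, here is a simplified sketch of how I drive this module statefully (the loader yielding consecutive chunks of the same set of sequences is illustrative):

model = LayerNorm_LSTM_Stateful(input_dim=16, output_dim=32)
hidden = None
for chunk, chunk_lens in loader:  # illustrative: consecutive slices of the same sequences
    out, hidden = model(chunk, lengths=chunk_lens, hidden=hidden)
    # Detach so backprop does not reach across mini-batch boundaries
    hidden = tuple(h.detach() for h in hidden)
    # ... loss computation and optimizer step omitted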

However, with this packing scheme, my training time is almost 1.5x longer than without pack_padded_sequence. Given the auto-regressive nature of RNNs, I am hoping for a more efficient way to train, perhaps by indexing the desired time-step’s hidden/cell state?

Does anyone have any ideas? Thank you.

Relative performance depends both on tensor shapes and hardware/software versions… One thing to try is batch_first=False mode. If that doesn’t help, then yes, indexing is a valid approach (at least for GRUs; with LSTMs, the hidden/cell state distinction complicates things).

Here is a rough snippet for the indexing approach (table has shape (time, batch, *), and gather_pts is a 1-D tensor of last valid time indices, one per batch element):

# table: (time, batch, *); gather_pts: (batch,) last valid time index per sequence
brange = torch.arange(table.size(1), dtype=torch.int64, device=table.device)
y = table[gather_pts, brange]  # advanced indexing -> (batch, *)
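
In context, for a single-layer unidirectional GRU in batch_first=False mode, that could look like the sketch below (shapes and names are illustrative; note that out only carries top-layer states, so this does not extend to multi-layer RNNs):

import torch
import torch.nn as nn

gru = nn.GRU(input_size=8, hidden_size=16)   # batch_first=False by default
x = torch.randn(10, 4, 8)                    # (time, batch, features)
lengths = torch.tensor([10, 7, 9, 5])        # valid lengths per sequence

out, _ = gru(x)                              # out: (time, batch, hidden)
gather_pts = lengths - 1                     # last valid time index per sequence
brange = torch.arange(x.size(1), device=x.device)
h_last = out[gather_pts, brange]             # (batch, hidden) at last valid steps
next_hidden = h_last.unsqueeze(0)            # (1, batch, hidden) to seed the next batch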

Or you can do the same with torch.gather:
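
For completeness, a self-contained torch.gather equivalent might look like this (dummy shapes for illustration):

import torch

time_steps, batch, feat = 5, 3, 4
table = torch.randn(time_steps, batch, feat)  # per-step RNN outputs, (time, batch, *)
gather_pts = torch.tensor([4, 2, 3])          # last valid time index per sequence

# Expand the indices so gather can select along the time dimension (dim 0)
idx = gather_pts.view(1, -1, 1).expand(1, batch, feat)
y = table.gather(0, idx).squeeze(0)           # -> (batch, feat)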

You can also try haste (https://github.com/lmnt-com/haste), a fast, simple, and open RNN library; I think it outputs correct states without packing.