How to Gain More Efficient Stateful RNN Training with Padded Sequences?

I am trying to train an LSTM network in “stateful” mode, i.e. initializing a mini-batch’s hidden/cell state with the previous mini-batch’s last hidden/cell state. However, each of my mini-batches has sequences with padding at the ends, which leads to my problem: I want to avoid initializing the next mini-batch’s hidden/cell state with hidden/cell states from padded time-steps.

I am able to achieve this by packing the padded sequences, like so:

import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class LayerNorm_LSTM_Stateful(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LayerNorm_LSTM_Stateful, self).__init__()
        self.layernorm = nn.LayerNorm(input_dim)
        self.lstm = nn.LSTM(input_dim, output_dim, batch_first=True, bidirectional=False)

    def forward(self, x, lengths=None, hidden=None):
        x = self.layernorm(x)
        if lengths is None:
            x, hidden = self.lstm(x, hidden)
        else:
            # Pack so the LSTM stops at each sequence's true length and the
            # returned hidden/cell state excludes padded time-steps.
            x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
            x, hidden = self.lstm(x, hidden)
            x, _ = pad_packed_sequence(x, batch_first=True)
        return x, hidden

However, with this packing scheme, my training time is almost 1.5x longer than without pack_padded_sequence. Given the auto-regressive nature of RNNs, I am hoping for a more efficient way to train — perhaps by indexing the desired time-step’s hidden/cell state directly?

Does anyone have any ideas? Thank you.

Relative performance depends on both tensor shapes and hardware/software versions. One thing to try is batch_first=False mode. If that doesn’t help, then yes, indexing is a valid approach (at least for GRUs; with LSTMs the hidden/cell state distinction complicates things).

Some random snippet of indexing (table is (time, batch, *), gather_pts is a 1-D tensor holding each sequence’s last valid time index):

brange = torch.arange(table.size(1), dtype=torch.int64, device=table.device)
y = table[gather_pts, brange]  # (batch, *): state at each sequence's last step

or you can do the same with torch.gather
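For concreteness, here is a self-contained sketch of both versions (shapes and index values are made up for illustration):

```python
import torch

# Toy shapes: table is (time, batch, hidden); gather_pts holds each
# sequence's last valid time index (both hypothetical values).
T, B, H = 5, 3, 4
table = torch.randn(T, B, H)
gather_pts = torch.tensor([4, 2, 3])

# Advanced indexing: pair each last index with its batch position.
brange = torch.arange(B, dtype=torch.int64, device=table.device)
y_index = table[gather_pts, brange]                # (batch, hidden)

# Equivalent torch.gather: expand indices to (1, batch, hidden)
# and gather along the time dimension.
idx = gather_pts.view(1, B, 1).expand(1, B, H)
y_gather = table.gather(0, idx).squeeze(0)         # (batch, hidden)

assert torch.equal(y_index, y_gather)
```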

You can also try haste (GitHub: lmnt-com/haste), a fast, simple, and open RNN library; I think it outputs correct states without packing.
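To illustrate the LSTM complication mentioned above: for a single-layer unidirectional LSTM the per-step hidden states h_t already equal the layer’s output, but the per-step cell states c_t are never exposed, so recovering the cell state at an arbitrary time-step needs a manual time loop. A correctness-only sketch (names and sizes are made up; the Python-level loop gives up the speed you were trying to regain):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, B, Din, Dh = 6, 3, 5, 4          # hypothetical sizes
lstm = nn.LSTM(Din, Dh)             # (time, batch, feature) layout
x = torch.randn(T, B, Din)
lengths = torch.tensor([6, 3, 5])   # valid length of each sequence

# Step through time manually so every intermediate (h_t, c_t) is kept;
# nn.LSTM itself only returns the final (h, c).
h = torch.zeros(1, B, Dh)
c = torch.zeros(1, B, Dh)
hs, cs = [], []
for t in range(T):
    _, (h, c) = lstm(x[t:t + 1], (h, c))
    hs.append(h)
    cs.append(c)
hs, cs = torch.cat(hs), torch.cat(cs)   # each (time, batch, hidden)

# Index the state at each sequence's last valid step, as in the snippet above.
brange = torch.arange(B)
h_last = hs[lengths - 1, brange]        # (batch, hidden)
c_last = cs[lengths - 1, brange]
```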