Managing gradients on variable length sequences

What is the best way to train an LSTM on data with variable-length sequences, where some of those sequences are too long to perform backpropagation in full?

My situation:

I am training an LSTM on many thousands of independent sequences. The individual sequences vary in length from 2 to 10,000 steps, but most are short (<100). I am using pack_padded_sequence() and pad_packed_sequence() to process the data in minibatches. The long sequences (around 10,000 steps) are causing memory and compute issues when I perform full backpropagation, so I am looking for smarter alternatives.
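For context, the pack/unpack step looks roughly like this (toy shapes made up for illustration):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Three variable-length sequences, zero-padded to the longest (batch_first=True).
lengths = torch.tensor([5, 3, 2])
padded = torch.zeros(3, 5, 8)  # (batch, max_len, features)

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

# Pack so the LSTM skips the padding steps entirely.
packed = pack_padded_sequence(padded, lengths, batch_first=True,
                              enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)

# Unpack back to a padded tensor (restored to original batch order).
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(out.shape)  # torch.Size([3, 5, 16])
```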

What I have tried so far:

First, I tried splitting the sequences into chunks of 128 steps and feeding each chunk through separately. This works OK, but it essentially resets the hidden and cell state after every 128 steps, so information that should be carried forward across chunks is lost. The training loop looks something like this:

for epoch in range(n_epochs):
    for batch in data_loader:
        inputs, targets = batch
        outputs = network(inputs)  # forward pass on one 128-step chunk
        loss = loss_function(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

To stop the loss of information, I tried processing each sequence in one go but performing backpropagation after every 128 steps. To do this, I manually manage the hidden and cell states, detaching them after each forward/backward pass. I also drop sequences from the minibatch once they have been completely processed. This approach also works but is dreadfully slow, probably because the minibatch shrinks as I drop samples. The training loop looks something like this:

for epoch in range(n_epochs):
    for batch in data_loader:
        hidden_state, cell_state = network.init_hidden()
        while True:
            (inputs, targets), (hidden_state, cell_state) = get_next_chunk_and_filter(batch, hidden_state, cell_state)
            if len(inputs) == 0:
                break  # every sequence in the batch has been fully processed
            outputs, (hidden_state, cell_state) = network(inputs, (hidden_state, cell_state))
            loss = loss_function(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # detach so the next backward pass stops here (truncated BPTT)
            hidden_state, cell_state = hidden_state.detach(), cell_state.detach()
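The detach() at the end of the loop is what truncates backpropagation: gradients cannot flow through a detached tensor, so each backward pass only reaches back as far as the previous detach. A minimal illustration:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
h = w * 3.0              # part of the autograd graph
h_detached = h.detach()  # same values, but cut off from the graph

loss = h_detached * 5.0  # requires_grad is False: no path back to w
# loss.backward() would raise here, since nothing connects loss to w.

# Without the detach, gradients flow all the way back to w:
grad = torch.autograd.grad(h * 5.0, w)[0]
print(grad)  # tensor(15.)
```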

Other ideas:

I am now considering replacing finished samples in the minibatch with fresh sequences, rather than removing them entirely. This seems like it would keep the batch full and make training more efficient, but it would also require some very messy juggling of samples, hidden states and cell states. I think it would work, but it isn't very elegant.
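For what it's worth, the juggling I have in mind could be sketched like this (hypothetical helper, names made up): keep a fixed number of batch slots, and when a sequence finishes, pull the next pending sequence into its slot and zero only that slot's state so nothing leaks between unrelated sequences:

```python
import torch

def refill_finished_slots(finished_mask, hidden, cell, pending):
    """Hypothetical sketch.
    finished_mask: (batch,) bool tensor marking slots whose sequence just ended
    hidden, cell:  (num_layers, batch, hidden_size) LSTM states
    pending:       list of sequences waiting to be scheduled
    """
    for slot in torch.nonzero(finished_mask).flatten().tolist():
        if not pending:
            continue  # nothing left to schedule into this slot
        pending.pop(0)  # assign the next pending sequence to this slot
        # Zero only this slot's state so the fresh sequence starts clean
        # while the other slots keep their carried-over states.
        hidden[:, slot, :] = 0.0
        cell[:, slot, :] = 0.0
    return hidden, cell
```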

Is there a better way to do this? It must be a really common problem.

Thanks!