Split input sequence based on model training loss

Hi,

When processing long input sequences in a seq2seq model, RNN- and Transformer-type models often perform worse. As a remedy, I want to split each input sequence into shorter subsequences (divide and conquer) and learn the split positions from the loss of the seq2seq task.

So the pipeline I want is:
Input -> splitting layer -> split subsequences -> encoder -> decoder -> loss

For example:
[[ [1,1,1], [2,2,2], [3,3,3], [4,4,4] ]] (shape 1 x 4 x 3, i.e. batch x len x feature) -> splitting layer -> [ [[1,1,1], [2,2,2], [3,3,3]], [[4,4,4]] ] (shape 2 x 3 x 3 after padding)
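Once the split positions are known, the reshape-and-pad step from the example can be done with `torch.split` plus `torch.nn.utils.rnn.pad_sequence`. A minimal sketch, assuming PyTorch, that all sequences in the input batch have the same length (no input padding), and that split positions are given as a hypothetical 0/1 `boundaries` tensor where 1 means "a segment ends at this position". The helper name `split_and_pad` is made up for illustration. Segments from every sequence in the batch are flattened into one new batch, and `owner` records which original sequence each segment came from.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def split_and_pad(x: torch.Tensor, boundaries: torch.Tensor):
    """Hypothetical helper: split x (batch, len, feature) into segments and pad.

    boundaries: 0/1 tensor (batch, len); 1 marks the last position of a segment.
    Returns padded (num_segments, max_seg_len, feature), segment lengths,
    and the index of the original sequence that owns each segment.
    """
    segments, seg_lengths, owner = [], [], []
    for b in range(x.size(0)):
        ends = boundaries[b].clone()
        ends[-1] = 1  # the final position always closes the last segment
        end_idx = torch.nonzero(ends).squeeze(-1)
        # Segment lengths are the gaps between consecutive end positions.
        lengths = torch.diff(end_idx, prepend=end_idx.new_tensor([-1]))
        for seg in torch.split(x[b], lengths.tolist(), dim=0):
            segments.append(seg)
            owner.append(b)
        seg_lengths.extend(lengths.tolist())
    padded = pad_sequence(segments, batch_first=True, padding_value=0.0)
    return padded, torch.tensor(seg_lengths), torch.tensor(owner)

x = torch.tensor([[[1., 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]])  # (1, 4, 3)
boundaries = torch.tensor([[0, 0, 1, 0]])  # split after the third step
padded, lengths, owner = split_and_pad(x, boundaries)
print(padded.shape)  # torch.Size([2, 3, 3])
print(lengths)       # tensor([3, 1])
```

The returned `lengths` can be passed to `pack_padded_sequence(..., batch_first=True, enforce_sorted=False)` for an RNN encoder, or turned into a key-padding mask for a Transformer. Note that `torch.split` and `pad_sequence` pass gradients to the values in `x`, but not to `boundaries`: the split positions are discrete indices, which is exactly why question 1 is hard.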

My idea is to use a BiLSTM to predict the split positions and to train the BiLSTM with the loss of the seq2seq model, but I am having trouble understanding:

  1. How does the loss back-propagate to the splitting layer?
  2. How do I reshape and pad the batch after finding the split positions?
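Because choosing split positions is a discrete decision, no gradient flows from the seq2seq loss back into the BiLSTM through the split itself. One common workaround (swapped in here, not something from the original post) is to treat each position's "segment ends here" decision as a Bernoulli sample from the BiLSTM and train it with the score-function (REINFORCE) estimator, using the negative per-sequence task loss as reward; Gumbel-softmax or straight-through relaxations are alternatives. A minimal sketch, assuming PyTorch: `SplitPolicy`, `ToySeq2Seq`, `split_by_ends`, and `reinforce_step` are hypothetical names, and `ToySeq2Seq` is only a stand-in (GRU encoder plus regression head) for the real encoder/decoder and loss.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

class SplitPolicy(nn.Module):
    """Hypothetical splitting layer: BiLSTM -> P(segment ends at position t)."""
    def __init__(self, feature_dim: int, hidden_dim: int = 16):
        super().__init__()
        self.bilstm = nn.LSTM(feature_dim, hidden_dim, batch_first=True,
                              bidirectional=True)
        self.proj = nn.Linear(2 * hidden_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h, _ = self.bilstm(x)                           # (batch, len, 2*hidden)
        return torch.sigmoid(self.proj(h)).squeeze(-1)  # (batch, len)

class ToySeq2Seq(nn.Module):
    """Stand-in for the real seq2seq model: encode each segment with a GRU,
    mean-pool the segment encodings per original sequence, regress a target."""
    def __init__(self, feature_dim: int, hidden_dim: int = 8):
        super().__init__()
        self.encoder = nn.GRU(feature_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, segments, owner, batch_size):
        lengths = torch.tensor([s.size(0) for s in segments])
        padded = pad_sequence(segments, batch_first=True)
        packed = pack_padded_sequence(padded, lengths, batch_first=True,
                                      enforce_sorted=False)
        _, h_n = self.encoder(packed)       # h_n: (1, num_segments, hidden)
        seg_repr = h_n[-1]                  # last hidden state of each segment
        counts = torch.bincount(owner, minlength=batch_size).unsqueeze(1)
        pooled = torch.zeros(batch_size, seg_repr.size(1)).index_add(0, owner, seg_repr)
        return self.head(pooled / counts)   # (batch, 1)

def split_by_ends(x_b, ends_b):
    """Split x_b (len, feature) after every position where ends_b == 1."""
    end_idx = torch.nonzero(ends_b).squeeze(-1)
    lengths = torch.diff(end_idx, prepend=end_idx.new_tensor([-1]))
    return list(torch.split(x_b, lengths.tolist(), dim=0))

def reinforce_step(policy, model, optimizer, x, target):
    probs = policy(x)                                  # (batch, len)
    dist = torch.distributions.Bernoulli(probs=probs)
    ends = dist.sample()                               # discrete: no gradient path
    # The last position always closes a segment, so it is not a decision.
    log_prob = dist.log_prob(ends)[:, :-1].sum(dim=1)  # (batch,)
    ends = ends.clone()
    ends[:, -1] = 1.0
    segments, owner = [], []
    for b in range(x.size(0)):
        segs = split_by_ends(x[b], ends[b])
        segments.extend(segs)
        owner.extend([b] * len(segs))
    pred = model(segments, torch.tensor(owner), x.size(0))
    task_loss = ((pred - target) ** 2).mean(dim=1)     # per-sequence loss
    # REINFORCE: reward = -task_loss, batch mean reward used as a baseline.
    advantage = -(task_loss.detach() - task_loss.detach().mean())
    policy_loss = -(advantage * log_prob).mean()
    loss = task_loss.mean() + policy_loss  # model: usual gradients; policy: REINFORCE
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return task_loss.mean().item()

torch.manual_seed(0)
policy, model = SplitPolicy(3), ToySeq2Seq(3)
optimizer = torch.optim.Adam(list(policy.parameters()) + list(model.parameters()), lr=1e-2)
x, target = torch.randn(4, 6, 3), torch.randn(4, 1)
loss_value = reinforce_step(policy, model, optimizer, x, target)
```

The encoder/decoder still trains with ordinary backprop; only the splitting layer relies on the sampled estimator. With a batch-mean baseline the policy gradient vanishes for a batch of one sequence, so a moving-average baseline may be preferable in practice, and REINFORCE gradients tend to be high-variance.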

Thanks!