Truncated BPTT with seq2seq

Hi,

I’m currently working on an encoder-decoder architecture that uses soft hybrid attention. Everything works fine, but I run into out-of-memory (OOM) issues with long sequences.
The best solution I have found is Truncated Backpropagation Through Time (TBPTT), which reduces memory consumption by splitting the output sequence into sub-sequences.
So my question is: how can I retain the computation graph for the encoder only, after the backward pass?

More specifically, the architecture is made up of an encoder that extracts a sequence of features from the input, an attention module that computes attention weights over the feature frames at a given time step, and an LSTM-based decoder that outputs class probabilities for that same time step.
This can be seen as a transformer-like architecture, but since the attention is hybrid, training cannot be parallelized over the whole sequence with teacher forcing.

Let’s say I want to use TBPTT with k1 = k2 = k, i.e. run backward every k steps and truncate the graph after k steps.
I tried calling backward() every k steps and detaching all intermediate tensors, but then the computation graph for the encoder is freed, so the encoder parameters are not updated for the following sub-sequences.
And if I use backward(retain_graph=True), the allocated memory grows at each step, which also leads to OOM.
The aim is to process the sequence in sub-sequences of length k and to update the weights of the whole network for each sub-sequence, without recomputing the features (encoder part) each time.

Am I doing this right? Are there alternatives to TBPTT in my case?

To make things clearer, here is what the code looks like without TBPTT:

# x: input
# y: target output sequence
# hidden: initial LSTM hidden state of the decoder module
loss = 0
features = EncoderModule(x)
att_weights = None  # no attention weights before the first step
for i in range(len(y) - 1):
    # the decoder attends over the encoder features
    att_weights, hidden, pred = AttentionDecoderModule(features, att_weights, hidden, y[i])
    loss += loss_func(pred, y[i + 1])
self.backward_loss(loss)
self.step_optimizer()
self.optimizer.zero_grad(set_to_none=True)

And I would like something like:

# x: input
# y: target output sequence
# hidden: initial LSTM hidden state of the decoder module
# k: TBPTT sub-sequence length
loss = 0
features = EncoderModule(x)
att_weights = None  # no attention weights before the first step
for i in range(len(y) - 1):
    att_weights, hidden, pred = AttentionDecoderModule(features, att_weights, hidden, y[i])
    loss += loss_func(pred, y[i + 1])
    if (i + 1) % k == 0 or i == len(y) - 2:  # every k steps, and at the end
        self.backward_loss(loss)
        # here I would like to keep the computation graph for the encoder only
        self.step_optimizer()
        self.optimizer.zero_grad(set_to_none=True)
        loss = 0
        # detach the recurrent state so the next sub-sequence starts a fresh graph
        hidden = tuple(h.detach() for h in hidden)
        att_weights = att_weights.detach()
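
The closest idea I have so far is to split the autograd graph at the encoder output with detach()/requires_grad_(), so that each sub-sequence’s backward only builds and frees the decoder/attention graph, and the gradient is then pushed through the encoder separately. Below is an untested sketch of what I mean (the name feats is just for illustration); I am not sure it is the right approach:

# untested sketch: split the autograd graph at the encoder output
features = EncoderModule(x)
# detached view that records its own gradient: backward() on the loss
# stops here, so the encoder graph is never touched (and never freed)
feats = features.detach().requires_grad_(True)
att_weights, loss = None, 0
for i in range(len(y) - 1):
    att_weights, hidden, pred = AttentionDecoderModule(feats, att_weights, hidden, y[i])
    loss += loss_func(pred, y[i + 1])
    if (i + 1) % k == 0 or i == len(y) - 2:
        self.backward_loss(loss)  # frees only the decoder/attention graph
        # re-inject the sub-sequence's gradient into the encoder graph;
        # retain_graph=True now only keeps the fixed-size encoder graph
        features.backward(feats.grad, retain_graph=True)
        feats.grad = None  # reset so the next sub-sequence is not double-counted
        self.step_optimizer()
        self.optimizer.zero_grad(set_to_none=True)
        loss = 0
        hidden = tuple(h.detach() for h in hidden)
        att_weights = att_weights.detach()

One thing I am unsure about with this sketch: after each optimizer step, the retained encoder graph still corresponds to the pre-update weights, so later sub-sequences backpropagate through slightly stale encoder activations. Since the whole point is to avoid recomputing the features, maybe that is acceptable?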