"Correct" way deal with hidden state, .detach() and retain graph in Recurrent scenarios

I am working with video data (individual frames) and want to improve performance using a recurrent unit (a convolutional LSTM, to be precise). My idea is to use information from the previous timesteps [t-2, t-1] to improve the prediction at [t].

I read through this thread and got confused about the different ways to deal with recurrence in PyTorch.

What is the “correct” way to implement the scenario given above?

  1. Let the timesteps flow through the LSTM and then detach it:
input.shape  # [bs, #channels, height, width]

output, hidden_state = model(input, hidden_state)
hidden_state = tuple(state.detach() for state in hidden_state) 
# detach so loss.backward() can be called without any retain_graph error

Inside the forward() method, the previous model predictions are saved and stacked with the current input:

input = combine_with_old_predictions()  # [bs, 3, #channels, height, width] with [t-2, t-1, t] stacked in dim 1
h, c = previous_hidden
output_inner = []
for t in range(num_timesteps):  # num_timesteps = 3
    # feed the timesteps one by one through the LSTM cell
    h, c = lstm_cells(input_tensor=input[:, t, :, :, :], cur_state=[h, c])
    output_inner.append(h)
layer_output = torch.stack(output_inner, dim=1)
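To make sure I understand approach 1, here is a minimal runnable sketch of it, using nn.LSTMCell on flattened features as a stand-in for the ConvLSTM cell (the shapes, data, and names here are illustrative, not my actual model):

```python
import torch
import torch.nn as nn

bs, feat, num_timesteps = 4, 8, 3
cell = nn.LSTMCell(feat, feat)  # stand-in for the ConvLSTM cell
criterion = nn.MSELoss()
h = torch.zeros(bs, feat)
c = torch.zeros(bs, feat)

for step in range(5):  # stands in for the batch loop
    x = torch.randn(bs, num_timesteps, feat)  # [t-2, t-1, t] stacked in dim 1
    outputs = []
    for t in range(num_timesteps):
        h, c = cell(x[:, t, :], (h, c))
        outputs.append(h)
    layer_output = torch.stack(outputs, dim=1)

    loss = criterion(layer_output[:, -1], torch.randn(bs, feat))
    # works without retain_graph because the graph only spans this batch
    # (no optimizer step here; gradients just accumulate for the demo)
    loss.backward()
    # detach so the next batch does not backprop into this one
    h, c = h.detach(), c.detach()
```

With the detach at the end of each iteration, loss.backward() raises no retain_graph error on the second batch.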
  2. Only feed in one timestep, but let the gradient flow back through the previous 2 timesteps:
detach_interval = 3
for i, batch in enumerate(data_loader):
    input, ground_truth = batch
    output, hidden_state = model(input, hidden_state)
    # input stays the same in the forward method of the model and is not stacked 
    # with the previous 2 timesteps [t-2, t-1]
    loss = criterion(output, ground_truth)
    if i % detach_interval == 0:
        # cut the graph every detach_interval steps
        hidden_state = tuple(state.detach() for state in hidden_state)

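For approach 2, one variant I tried accumulates the loss over detach_interval steps, calls backward() once, and then detaches, which avoids retain_graph entirely. A minimal runnable sketch with nn.LSTMCell as a stand-in for the ConvLSTM and random data in place of my loader:

```python
import torch
import torch.nn as nn

bs, feat, detach_interval = 4, 8, 3
cell = nn.LSTMCell(feat, feat)  # stand-in for the ConvLSTM cell
criterion = nn.MSELoss()
h = torch.zeros(bs, feat)
c = torch.zeros(bs, feat)

accumulated_loss = 0.0
for i in range(9):  # stands in for enumerate(data_loader)
    x = torch.randn(bs, feat)    # one timestep per batch
    gt = torch.randn(bs, feat)
    h, c = cell(x, (h, c))
    accumulated_loss = accumulated_loss + criterion(h, gt)

    if (i + 1) % detach_interval == 0:
        # backprop through the last detach_interval timesteps at once
        accumulated_loss.backward()
        accumulated_loss = 0.0
        # cut the graph so the next window starts fresh
        h, c = h.detach(), c.detach()
```

Is this a reasonable way to do truncated backprop through time, or should I instead call loss.backward(retain_graph=True) on every step and only detach every detach_interval steps?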
What is the “correct LSTM-way” to deal with this scenario?
Any help is much appreciated!