I am working with video data (individual frames) and want to improve performance using a recurrent unit (a convolutional LSTM, to be precise). My idea is to use information from the previous timesteps [t-2, t-1] to improve the prediction for [t].
I read through this thread and got confused about the different ways to deal with recurrence in PyTorch.
What is the “correct” way to implement the scenario given above?
- Let the timesteps flow through the LSTM and then detach the hidden state:
```python
input.shape  # [bs, channels, height, width]
output, hidden_state = model(input, hidden_state)
# detach so loss.backward() can be called without a retain_graph error
hidden_state = tuple(state.detach() for state in hidden_state)
```
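To make sure I understand option 1, here is a runnable toy version of that loop. `ToyConvLSTM` is just a stand-in I wrote for this post (not my actual network), and the zero-initialization of the state on the first frame is an assumption:

```python
import torch
import torch.nn as nn

class ToyConvLSTM(nn.Module):
    """Stand-in ConvLSTM: one conv computes all four gates, a 1x1 conv
    maps the hidden state back to image channels."""
    def __init__(self, channels=3, hidden=8):
        super().__init__()
        self.gates = nn.Conv2d(channels + hidden, 4 * hidden, 3, padding=1)
        self.head = nn.Conv2d(hidden, channels, 1)
        self.hidden = hidden

    def forward(self, x, state=None):
        if state is None:  # zero-init on the first frame (an assumption)
            b, _, hgt, wid = x.shape
            state = (x.new_zeros(b, self.hidden, hgt, wid),
                     x.new_zeros(b, self.hidden, hgt, wid))
        h, c = state
        i, f, g, o = self.gates(torch.cat([x, h], dim=1)).chunk(4, dim=1)
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return self.head(h), (h, c)

model = ToyConvLSTM()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

frames = torch.randn(10, 2, 3, 16, 16)  # [timesteps, bs, channels, height, width]
hidden_state = None
for t in range(1, len(frames)):  # predict frame t from frame t-1
    output, hidden_state = model(frames[t - 1], hidden_state)
    loss = criterion(output, frames[t])
    optimizer.zero_grad()
    loss.backward()  # graph only covers this step, so no retain_graph needed
    optimizer.step()
    # detach so the next backward() does not reach into this step's freed graph
    hidden_state = tuple(s.detach() for s in hidden_state)
```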
- Inside the forward() method, the previous model predictions are saved and stacked with the current input:
```python
# [bs, 3, channels, height, width] with [t-2, t-1, t] stacked in dim 1
input = combine_with_old_predictions()
h, c = previous_hidden
output_inner = []
for t in range(num_timesteps):  # num_timesteps = 3
    # feed the timesteps one by one through the LSTM cell
    h, c = lstm_cells(input_tensor=input[:, t, :, :, :], cur_state=[h, c])
    output_inner.append(h)
layer_output = torch.stack(output_inner, dim=1)
```
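The `combine_with_old_predictions()` part is only hinted at above; on the training side I picture it as a small rolling buffer of the last two (detached) predictions. A sketch, where padding with zero tensors before any real predictions exist is my assumption:

```python
from collections import deque
import torch

prev_preds = deque(maxlen=2)  # holds the predictions for [t-2, t-1]

def combine_with_old_predictions(frame):
    """Stack [t-2, t-1, t] along a new time dimension (dim=1)."""
    while len(prev_preds) < 2:  # warm-up: pad with zeros (an assumption)
        prev_preds.append(torch.zeros_like(frame))
    return torch.stack([prev_preds[0], prev_preds[1], frame], dim=1)

frame = torch.randn(2, 3, 16, 16)             # [bs, channels, height, width]
stacked = combine_with_old_predictions(frame)
print(stacked.shape)                          # torch.Size([2, 3, 3, 16, 16])
# after the forward pass, the newest prediction replaces the oldest one:
# prev_preds.append(output.detach())
```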
- Only feed in one timestep, but let the gradient flow back through the previous two timesteps:
```python
detach_interval = 3
for i, batch in enumerate(data_loader):
    input, ground_truth = batch
    # input stays the same in the forward method of the model and is not
    # stacked with the previous two timesteps [t-2, t-1]
    output, hidden_state = model(input, hidden_state)
    loss = criterion(output, ground_truth)
    if i % detach_interval == 0:
        loss.backward(retain_graph=False)
    else:
        loss.backward(retain_graph=True)
```
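Here is a runnable toy version of option 3, with a plain `nn.LSTMCell` standing in for the ConvLSTM. One thing I noticed while testing it: the state has to be detached right after the freeing `backward()`, otherwise the next iteration's `backward()` runs into the already-freed part of the graph and raises exactly the retain_graph error:

```python
import torch
import torch.nn as nn

cell = nn.LSTMCell(input_size=4, hidden_size=4)  # stand-in for the ConvLSTM cell
criterion = nn.MSELoss()
detach_interval = 3

h = torch.zeros(2, 4)  # [bs, hidden_size], zero-init is an assumption
c = torch.zeros(2, 4)
for i in range(9):  # stand-in for the data_loader loop
    input = torch.randn(2, 4)
    ground_truth = torch.randn(2, 4)
    h, c = cell(input, (h, c))
    loss = criterion(h, ground_truth)
    if i % detach_interval == 0:
        loss.backward(retain_graph=False)  # frees the accumulated graph
        h, c = h.detach(), c.detach()      # cut the state off before the next step
    else:
        loss.backward(retain_graph=True)   # keep the graph for later backwards
```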
Which of these is the “correct LSTM way” to deal with this scenario?
Any help is much appreciated!