"Correct" way deal with hidden state, .detach() and retain graph in Recurrent scenarios

I am working with video data (individual frames) and want to improve performance using a recurrent unit (a convolutional LSTM, to be precise). My idea is to use information from the previous timesteps [t-2, t-1] to improve the prediction at [t].

I read through this thread and got confused about the different ways to deal with recurrence in PyTorch.

What is the “correct” way to implement the scenario given above?

  1. Let the timesteps flow through the LSTM and then detach it:
input.shape  # [bs, #channels, height, width]

output, hidden_state = model(input, hidden_state)
hidden_state = tuple(state.detach() for state in hidden_state) 
# detach so loss.backward() can be called without any retain_graph error

Inside the forward() method, the previous model predictions are saved and stacked with the current input:

input = combine_with_old_predictions()  # [bs, 3, #channels, height, width] with [t-2, t-1, t] stacked in dim 1
h, c = previous_hidden
output_inner = []
for t in range(num_timesteps):  # num_timesteps = 3
    # feed the timesteps one by one through the LSTM cell
    h, c = lstm_cells(input_tensor=input[:, t, :, :, :], cur_state=[h, c])
    output_inner.append(h)
layer_output = torch.stack(output_inner, dim=1)
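To make sure I understand approach 1, here is a minimal runnable sketch of it, using nn.LSTMCell on flattened features as a stand-in for the ConvLSTM cell (the shapes, data, and names here are illustrative, not my actual model):

```python
import torch
import torch.nn as nn

bs, feat, num_timesteps = 4, 8, 3
cell = nn.LSTMCell(feat, feat)  # stand-in for the ConvLSTM cell
criterion = nn.MSELoss()
h = torch.zeros(bs, feat)
c = torch.zeros(bs, feat)

for step in range(5):  # stands in for the batch loop
    x = torch.randn(bs, num_timesteps, feat)  # [t-2, t-1, t] stacked in dim 1
    outputs = []
    for t in range(num_timesteps):
        h, c = cell(x[:, t, :], (h, c))
        outputs.append(h)
    layer_output = torch.stack(outputs, dim=1)

    loss = criterion(layer_output[:, -1], torch.randn(bs, feat))
    # works without retain_graph because the graph only spans this batch
    # (no optimizer step here; gradients just accumulate for the demo)
    loss.backward()
    # detach so the next batch does not backprop into this one
    h, c = h.detach(), c.detach()
```

With the detach at the end of each iteration, loss.backward() raises no retain_graph error on the second batch.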
  2. Only feed in one timestep, but let the gradient flow back through the previous 2 timesteps:
detach_interval = 3
for i, batch in enumerate(data_loader):
    input, ground_truth = batch
    output, hidden_state = model(input, hidden_state)
    # input stays the same in the forward method of the model and is not stacked 
    # with the previous 2 timesteps [t-2, t-1]
    loss = criterion(output, ground_truth)
    if i % detach_interval == 0:
        # cut the graph every detach_interval steps
        hidden_state = tuple(state.detach() for state in hidden_state)

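For approach 2, one variant I tried accumulates the loss over detach_interval steps, calls backward() once, and then detaches, which avoids retain_graph entirely. A minimal runnable sketch with nn.LSTMCell as a stand-in for the ConvLSTM and random data in place of my loader:

```python
import torch
import torch.nn as nn

bs, feat, detach_interval = 4, 8, 3
cell = nn.LSTMCell(feat, feat)  # stand-in for the ConvLSTM cell
criterion = nn.MSELoss()
h = torch.zeros(bs, feat)
c = torch.zeros(bs, feat)

accumulated_loss = 0.0
for i in range(9):  # stands in for enumerate(data_loader)
    x = torch.randn(bs, feat)    # one timestep per batch
    gt = torch.randn(bs, feat)
    h, c = cell(x, (h, c))
    accumulated_loss = accumulated_loss + criterion(h, gt)

    if (i + 1) % detach_interval == 0:
        # backprop through the last detach_interval timesteps at once
        accumulated_loss.backward()
        accumulated_loss = 0.0
        # cut the graph so the next window starts fresh
        h, c = h.detach(), c.detach()
```

Is this a reasonable way to do truncated backprop through time, or should I instead call loss.backward(retain_graph=True) on every step and only detach every detach_interval steps?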
What is the “correct LSTM-way” to deal with this scenario?
Any help is much appreciated!