Learnable hidden state plus truncated back-propagation through time

Hello everyone,

I’m currently examining the potential of learning the hidden state initialization (instead of using zeros) for a recurrent neural net (e.g. a GRU or LSTM). However, I would like to combine the learnable initialization with truncated back-propagation through time (truncated BPTT).

Starting from this simple Time Series Prediction Problem, I gave it a try and implemented a learnable hidden state initialization under truncated BPTT. My adjustments can be found in this Colab notebook.

Conceptually, I have to wrap the initial state in nn.Parameter(torch.zeros(hidden_state_size), requires_grad=True) to make it learnable. This is pretty simple if no truncation is used, but once sequences are fed in chunks and the latest hidden state is carried over between chunks, things get more complicated. I’m now wondering whether my code is conceptually correct. Could anybody take a brief look at it, or suggest an approach to assess it (e.g. visualizing the computation graph, plotting gradients)?
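To make the setup concrete, here is a minimal sketch of what I mean (the class name, shapes, and dummy loss are my own assumptions for illustration, not code from the notebook): the initial hidden state is an nn.Parameter, each new sequence starts from it, and the carried-over state is detached at every chunk boundary.

```python
import torch
import torch.nn as nn

class GRUWithLearnableInit(nn.Module):
    """GRU whose initial hidden state is a trainable parameter instead of zeros.

    Hypothetical sketch; names and shapes are assumptions for illustration.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        # Learnable initial state, shape (num_layers, 1, hidden_size);
        # expanded over the batch dimension in forward().
        self.h0 = nn.Parameter(torch.zeros(1, 1, hidden_size))

    def forward(self, x, h=None):
        if h is None:  # start of a new sequence: use the learned init
            h = self.h0.expand(-1, x.size(0), -1).contiguous()
        return self.gru(x, h)

model = GRUWithLearnableInit(input_size=3, hidden_size=8)
seq = torch.randn(2, 20, 3)          # batch of 2, sequence length 20

h = None
for chunk in seq.split(5, dim=1):    # truncated BPTT with chunk length 5
    out, h = model(chunk, h)
    loss = out.pow(2).mean()         # dummy loss for the sketch
    loss.backward()
    h = h.detach()                   # cut the graph at the chunk boundary
```

After running this, model.h0.grad is populated only by the first chunk of the sequence, since every later chunk starts from a detached state.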

In the end, I’d like to add this to a Deep Reinforcement Learning algorithm that uses truncated BPTT. Collecting the hidden states and their gradients inside the replay buffer will be another challenge, but it isn’t worth tackling as long as I’m uncertain about the learnable hidden state initialization itself.

Many thanks in advance for any opinions and advice on this matter. I hope this thread will be helpful to others as well.

From a theoretical point of view, a fully learnable hidden state initialization is not compatible with truncated BPTT: gradients are cut at each chunk boundary, so losses computed on later chunks cannot be backpropagated to the original initialization. Only the first chunk of each sequence contributes a gradient to the learned initial state.
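This is easy to verify empirically. In the following minimal sketch (a toy GRU of my own construction, not the notebook's code), I snapshot h0.grad after each chunk's backward pass: the second chunk starts from a detached state, so h0's accumulated gradient does not change after the first chunk.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(1, 4, batch_first=True)
h0 = nn.Parameter(torch.zeros(1, 1, 4))   # learnable initial state
x = torch.randn(1, 10, 1)                 # one sequence of length 10

grads = []
h = h0
for chunk in x.split(5, dim=1):           # two chunks of length 5
    out, h = gru(chunk, h)
    out.sum().backward()                  # dummy loss per chunk
    grads.append(h0.grad.clone())         # snapshot h0's accumulated gradient
    h = h.detach()                        # truncation point

# grads[0] == grads[1]: the second chunk's backward pass never reaches h0.
```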