How to backpropagate a loss through a time-series RNN?

I am trying to implement an RNN-based model for time-series data and any help would be much appreciated! I have a reward signal I would like to utilize to backpropagate a loss through the RNN every n steps. I cannot seem to find a way to backpropagate anything without detaching the hidden state, but I don’t think that is a good approach in this case.
Let’s start from a simple example, taken straight from the PyTorch documentation but amended so that the optimizer steps inside the time loop:

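import torch
import torch.nn as nn

retain_graph = False  # set to True to reproduce the second error described below

rnn = nn.GRUCell(10, 20)
optimizer = torch.optim.SGD(rnn.parameters(), lr=1e-3)
input = torch.randn(6, 3, 10)  # (time steps, batch, features)
hx = torch.randn(3, 20)        # initial hidden state (batch, hidden)
output = []
for i in range(6):
    hx = rnn(input[i], hx)     # hx stays attached to the graph of all earlier steps
    output.append(hx)
    hx.mean().backward(retain_graph=retain_graph)
    optimizer.step()           # updates the weights in place
    optimizer.zero_grad()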

In the above script, if retain_graph is False, I get a “RuntimeError: Trying to backward through the graph a second time”. When retain_graph is True, I get “RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [20, 60]]”.
I’ve seen some time-series RNN implementations online that don’t seem to bother with detaching the hidden state, so I suppose this used to work well in previous PyTorch versions.
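For reference, here is a minimal sketch of what is most likely behind the second error: optimizer.step() updates the parameters in place, so a graph retained from an earlier time step may still point at weight tensors that have since changed. The check below only shows that the update is in place (same storage, different values); the names x, weight_before and ptr_before are illustrative and not part of the original script.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.GRUCell(10, 20)
optimizer = torch.optim.SGD(rnn.parameters(), lr=1e-3)

x = torch.randn(3, 10)    # one time step of input (illustrative)
hx = torch.randn(3, 20)   # initial hidden state

hx = rnn(x, hx)           # the graph built here refers to the current weights

# Snapshot the hidden-to-hidden weights before the update.
weight_before = rnn.weight_hh.detach().clone()
ptr_before = rnn.weight_hh.data_ptr()

hx.mean().backward(retain_graph=True)
optimizer.step()          # SGD writes the new values into the same tensor
optimizer.zero_grad()

same_storage = rnn.weight_hh.data_ptr() == ptr_before
values_changed = not torch.equal(rnn.weight_hh.detach(), weight_before)
print(same_storage, values_changed)
```

If both flags are true, any later backward pass that walks back into the retained graph would be using weights that no longer match the ones from the forward pass, which is what autograd's in-place check guards against.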

Hey,
To keep the states (cell and hidden) you only need to initialize the weights before the training phase or in the init method (since it is only called once).
If you need any further help I can help you since I am developing work in that field.
Regards
André

Hello,
Thank you for your quick reply, but I’m afraid I don’t quite understand what you mean. Just to clarify, I want to use the updated hidden states for subsequent RNN steps while also being able to backpropagate through the network at various intervals (i.e. in an online learning fashion). I do initialize the states before training commences (as the example above illustrates), but I want to use the updated states for subsequent forward passes without reinitializing them or detaching them from the graph at each forward pass, if that makes sense. If you reckon you have a solution for the above, a quick code snippet would be really helpful! Thank you for your time!
Andrei
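
One commonly used pattern for this kind of setup is truncated backpropagation through time: accumulate the loss over a window of n steps, backpropagate once, step the optimizer, and only then detach the hidden state. This does involve a detach, but only at window boundaries, so the hidden state values still carry forward from step to step. The sketch below is a rough illustration under assumptions, not a definitive answer: n = 2 is an arbitrary interval, and hx.mean() is a stand-in for whatever loss is derived from the reward signal.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.GRUCell(10, 20)
optimizer = torch.optim.SGD(rnn.parameters(), lr=1e-3)

inputs = torch.randn(6, 3, 10)  # (time steps, batch, features)
hx = torch.randn(3, 20)         # initial hidden state
n = 2                           # assumed update interval (hypothetical)
loss = 0.0
num_updates = 0

for i in range(inputs.size(0)):
    hx = rnn(inputs[i], hx)
    loss = loss + hx.mean()     # placeholder for the reward-based loss

    if (i + 1) % n == 0:
        optimizer.zero_grad()
        loss.backward()         # backpropagates through the last n steps only
        optimizer.step()        # weights change in place after the backward pass
        hx = hx.detach()        # keep the state values, drop the old graph
        loss = 0.0
        num_updates += 1

print(num_updates)
```

Because the graph of each window is built and consumed between two optimizer steps, no backward pass ever reaches weights that were modified after its forward pass, and retain_graph is not needed. The trade-off is that gradients do not flow further back than n steps.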