Suppose (it is a toy example) we want to implement an RNN that maps an input sequence $x_t$ to an output sequence $y_t$ via hidden states $h_t$.
We assume the outputs to be scalars. They are computed as $y_t=F(x_t,h_{t-1})$ with a scalar-valued feedforward neural network $F$.
Let us further assume that the hidden states should be computed as
$$ h_t = \frac{\partial F}{\partial h}(x_t,h_{t-1}). $$
How can this derivative be computed in the forward pass?
The problem is that the derivative is w.r.t. the non-leaf node $h_{t-1}$.
So we cannot simply do:
def RNN_step(x, h_prev):
h_prev.requires_grad = True # this does not work
y = F(x, h_prev)
y.sum().backward(create_graph=True) # summing over the batch dimension
h = h_prev.grad
h_prev.requires_grad = False
return y, h