# RNN loop uses previous time step's hidden state to predict current output in the PyTorch tutorial

```python
import torch
import torch.nn as nn


class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        # Both layers read the concatenation [input, previous hidden state].
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        # `hidden` is the state carried over from the previous time step.
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)  # new hidden state
        output = self.i2o(combined)  # computed from the previous state, not the new one
        output = self.softmax(output)
        return output, hidden
```

Hi, I’ve just begun using PyTorch and was going through the RNN example in https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial. As I understand it, the output at the current time step should be predicted from the current hidden state. Here, though, the previous time step’s hidden state seems to be used instead. Can someone explain? Thanks.

It may be a bit confusing, but the indexing is shifted by one. Before the first step you already have an a priori hidden state, so the first step looks like `rnn(input, prior_hidden) -> hidden`. This way, tensor sizes match along the time dimension.
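
To make the shift concrete, here is a rough NumPy sketch of unrolling such a cell over a sequence. It is a stand-in for the tutorial's module, not the tutorial's code: the sizes, the random weights `W_i2h` and `W_i2o`, and the `step` helper are all made up for illustration, and biases are left out.

```python
import numpy as np

rng = np.random.default_rng(0)
input_size, hidden_size, output_size = 5, 4, 3  # illustrative sizes
seq_len = 6

# Hypothetical weights standing in for the i2h and i2o layers (no biases).
W_i2h = 0.1 * rng.normal(size=(hidden_size, input_size + hidden_size))
W_i2o = 0.1 * rng.normal(size=(output_size, input_size + hidden_size))


def step(x_t, h_prev):
    """One step in the style of forward(): both layers read [x_t, h_prev]."""
    combined = np.concatenate([x_t, h_prev])
    h_t = W_i2h @ combined
    logits = W_i2o @ combined
    shifted = logits - logits.max()               # for numerical stability
    o_t = shifted - np.log(np.exp(shifted).sum())  # log-softmax
    return o_t, h_t


xs = rng.normal(size=(seq_len, input_size))
hiddens = [np.zeros(hidden_size)]  # h_0, the a priori state before any input
outputs = []
for x_t in xs:
    o_t, h_t = step(x_t, hiddens[-1])  # step t consumes h_{t-1}
    outputs.append(o_t)
    hiddens.append(h_t)

print(len(xs), len(outputs), len(hiddens))
```

With `seq_len` inputs you get `seq_len` outputs but `seq_len + 1` hidden states, because `h_0` exists before the first input arrives. Pairing each input with the state that precedes it is what keeps the per-step shapes lined up.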


There are several variables used in a typical RNN.

1. Current hidden state: `h_t`
2. Last hidden state: `h_t-1`
3. Current input: `x_t`
4. Current output: `o_t`

In this example, `h_t` is computed from `h_t-1` and `x_t`. `o_t` is also computed from `h_t-1` and `x_t`, not from `h_t`.
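
Written out for the module in the question, that reads roughly as follows. The symbols `W_i2h`, `b_i2h`, `W_i2o`, `b_i2o` are notation introduced here for the weights and biases of the two `nn.Linear` layers, and `[x_t; h_{t-1}]` denotes the concatenation done by `torch.cat`. Note that this particular module applies no nonlinearity to `h_t`.

```latex
\begin{aligned}
h_t &= W_{\mathrm{i2h}}\,[x_t;\, h_{t-1}] + b_{\mathrm{i2h}} \\
o_t &= \log \operatorname{softmax}\bigl(W_{\mathrm{i2o}}\,[x_t;\, h_{t-1}] + b_{\mathrm{i2o}}\bigr)
\end{aligned}
```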

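```python
import numpy as np

def softmax(z):  # helper not shown in the post; assumed to normalize over axis 0
    e = np.exp(z - z.max(axis=0))
    return e / e.sum(axis=0)

def rnn_cell_forward(xt, a_prev, parameters):
    # Assumption: `parameters` is a dict keyed by these names.
    Wax, Waa, Wya = parameters["Wax"], parameters["Waa"], parameters["Wya"]
    ba, by = parameters["ba"], parameters["by"]
    a_next = np.tanh(np.dot(Waa, a_prev) + np.dot(Wax, xt) + ba)
    yt_pred = softmax(np.dot(Wya, a_next) + by)  # reads the new state a_next
    cache = (a_next, a_prev, xt, parameters)
    return a_next, yt_pred, cache
```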

These are the forward pass equations I’m most familiar with. At the very first time step, the initial hidden state `a_prev` (h initialized to 0) is never used directly to make output predictions, whereas in the PyTorch tutorial above, `a_prev` gets to make the first output prediction. Is this a change in architecture or in implementation? Sorry, I’m still sort of confused.

The implementation you posted above is different from the one in the PyTorch tutorial. I think your understanding is correct in both cases.
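
If it helps, here is a small NumPy sketch that puts the two variants next to each other. Everything in it is illustrative: `step_tutorial` and `step_new_state` are hypothetical names, the sizes and random weights are made up, and biases are omitted. The only point is where the output head reads from.

```python
import numpy as np

rng = np.random.default_rng(0)
n_x, n_h, n_y = 3, 4, 2  # illustrative sizes


def log_softmax(z):
    z = z - z.max()
    return z - np.log(np.exp(z).sum())


def step_tutorial(x, h_prev, W_h, W_o):
    """Tutorial style: both heads read [x_t, h_{t-1}]."""
    combined = np.concatenate([x, h_prev])
    h_next = W_h @ combined
    out = log_softmax(W_o @ combined)
    return out, h_next


def step_new_state(x, h_prev, W_h, W_y):
    """rnn_cell_forward style: the output head reads the new state."""
    h_next = np.tanh(W_h @ np.concatenate([x, h_prev]))
    out = log_softmax(W_y @ h_next)
    return out, h_next


x = rng.normal(size=n_x)
h0 = np.zeros(n_h)
W_h = rng.normal(size=(n_h, n_x + n_h))
W_o = rng.normal(size=(n_y, n_x + n_h))
W_y = rng.normal(size=(n_y, n_h))

# Perturb the recurrent weights and check which first-step output moves.
W_h2 = W_h + 1.0
out_a, _ = step_tutorial(x, h0, W_h, W_o)
out_a2, _ = step_tutorial(x, h0, W_h2, W_o)
out_b, _ = step_new_state(x, h0, W_h, W_y)
out_b2, _ = step_new_state(x, h0, W_h2, W_y)

tutorial_unchanged = np.allclose(out_a, out_a2)
new_state_changed = not np.allclose(out_b, out_b2)
print(tutorial_unchanged, new_state_changed)
```

In the tutorial-style step the output at time `t` never passes through the hidden-state weights applied at time `t`, so perturbing them leaves that output alone. In the other style the output is read off the freshly computed state, so it should move. Both are workable designs; they just place the output one step apart relative to the hidden state.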

Thanks!! I’ll continue to think it’s just a difference in implementation.