RNN loop uses previous time step's hidden state to predict current output in the PyTorch tutorial

import torch
import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        # Both layers read the concatenation of the current input and the previous hidden state
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)      # new hidden state h_t
        output = self.i2o(combined)      # output o_t, computed from [x_t, h_t-1]
        output = self.softmax(output)
        return output, hidden

Hi, I’ve just begun using PyTorch and was working through the RNN example in https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial. As per my understanding, the current time step’s output should be predicted using the current hidden state, but here the previous time step’s hidden state seems to be used instead. Can I get an explanation? Thanks.

It may be a bit confusing, but it is just shifted. The reason is that before the first step you already have an a priori hidden state, so the first step looks like rnn(input[0], prior_hidden) -> hidden[0]. This way, tensor sizes match in the time dimension.
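Roughly, the loop in the tutorial looks like the sketch below. This assumes the RNN class quoted above; the sizes are only illustrative, not the tutorial’s exact values.

import torch

rnn = RNN(input_size=57, hidden_size=128, output_size=18)

seq = torch.zeros(5, 1, 57)       # 5 time steps, batch of 1, one-hot inputs
hidden = torch.zeros(1, 128)      # the a priori hidden state, before any input

for t in range(seq.size(0)):
    # output at step t comes from (input[t], hidden carried over from step t-1);
    # the returned hidden becomes hidden[t] and feeds the next step
    output, hidden = rnn(seq[t], hidden)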


There are several variables used in a typical RNN.

  1. Current hidden state: h_t
  2. Last hidden state: h_t-1
  3. Current input: x_t
  4. Current output: o_t

In this example, h_t is computed using h_t-1 and x_t. o_t is also computed using h_t-1 and x_t, not the freshly computed h_t.
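To make that concrete, here is a toy sketch contrasting the tutorial-style cell with the more common formulation where o_t is read off the freshly computed h_t. The layer names and sizes are just illustrative assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

x_t = torch.randn(1, 4)        # current input x_t
h_prev = torch.zeros(1, 3)     # previous hidden state h_t-1

i2h = nn.Linear(4 + 3, 3)      # [x_t, h_t-1] -> new hidden state
i2o = nn.Linear(4 + 3, 5)      # [x_t, h_t-1] -> output (tutorial style)
h2o = nn.Linear(3, 5)          # h_t -> output ("textbook" style)

combined = torch.cat((x_t, h_prev), 1)

# Tutorial-style cell: o_t is computed directly from [x_t, h_t-1]
h_t = i2h(combined)
o_t = F.log_softmax(i2o(combined), dim=1)

# "Textbook"-style cell: o_t is computed from the freshly updated h_t
h_t_alt = torch.tanh(i2h(combined))
o_t_alt = F.log_softmax(h2o(h_t_alt), dim=1)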

import numpy as np

def softmax(x):                     # column-wise softmax
    e = np.exp(x - np.max(x, axis=0, keepdims=True))
    return e / e.sum(axis=0, keepdims=True)

def rnn_cell_forward(xt, a_prev, parameters):
    # Unpack weights; parameters is assumed to be a dict of NumPy arrays
    Wax, Waa, Wya = parameters["Wax"], parameters["Waa"], parameters["Wya"]
    ba, by = parameters["ba"], parameters["by"]
    a_next = np.tanh(np.dot(Waa, a_prev) + np.dot(Wax, xt) + ba)  # new hidden state
    yt_pred = softmax(np.dot(Wya, a_next) + by)                   # prediction read off a_next
    cache = (a_next, a_prev, xt, parameters)
    return a_next, yt_pred, cache

These are the forward-pass equations I’m most familiar with. At the very first time step, the initial hidden state a_prev (h[0], initialized to 0) is never used directly to make an output prediction, whereas in the PyTorch tutorial above, a_prev does contribute directly to the first output prediction. Is this a difference in architecture/implementation? Sorry, I’m still somewhat confused.
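For reference, a minimal driver loop for the cell above (the parameters dict and sizes are made up for illustration) shows that the initial a_prev only reaches the predictions through a_next:

import numpy as np

n_x, n_a, n_y = 4, 3, 5                       # input, hidden, output sizes
rng = np.random.default_rng(0)
parameters = {"Wax": rng.standard_normal((n_a, n_x)),
              "Waa": rng.standard_normal((n_a, n_a)),
              "Wya": rng.standard_normal((n_y, n_a)),
              "ba": np.zeros((n_a, 1)), "by": np.zeros((n_y, 1))}

a_prev = np.zeros((n_a, 1))                   # h[0]: feeds a_next, never y directly
for t in range(6):
    xt = rng.standard_normal((n_x, 1))
    a_prev, yt_pred, _ = rnn_cell_forward(xt, a_prev, parameters)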

The implementation you posted is different from the one in the PyTorch tutorial: in the tutorial the output is computed from the concatenation of x_t and h_t-1, while in your NumPy version it is computed from the freshly updated hidden state a_next. I think your understanding is correct in both cases.

Thanks!! I’ll continue to think it’s just a difference in implementation.