Correct way to feed data to RNN in PyTorch

I have a data sequence a of shape [seq_len, 2], where seq_len is the length of the sequence. The elements of a[:, 0] are correlated in time, and so are the elements of a[:, 1], but a[:, 0] and a[:, 1] are independent of each other. For training, I prepare data of shape [batch_size, seq_len, 2]. The bidirectional RNN (BRNN) I use is initialized as

```python
import torch.nn as nn

birnn_layer = nn.RNN(input_size=2, hidden_size=100, batch_first=True, bidirectional=True)
```
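A quick shape check makes the layer's inputs and outputs concrete. The sketch below assumes illustrative sizes (batch_size=4, seq_len=10) and random data standing in for the real sequences. With bidirectional=True, the forward and backward hidden states are concatenated, so the last output dimension is 2 * hidden_size.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len = 4, 10  # illustrative sizes, not from the original post
birnn_layer = nn.RNN(input_size=2, hidden_size=100, batch_first=True, bidirectional=True)

# Random stand-in for the real data: [batch_size, seq_len, 2]
x = torch.randn(batch_size, seq_len, 2)
output, h_n = birnn_layer(x)

# Forward and backward hidden states are concatenated on the last dim
print(output.shape)  # torch.Size([4, 10, 200])
# h_n ignores batch_first: [num_layers * num_directions, batch_size, hidden_size]
print(h_n.shape)     # torch.Size([2, 4, 100])

# Forward direction (0) ends at the last timestep; backward direction (1) ends at the first
print(torch.allclose(h_n[0], output[:, -1, :100]))  # True
print(torch.allclose(h_n[1], output[:, 0, 100:]))   # True
```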

From the docs,

input_size – The number of expected features in the input x

hidden_size – The number of features in the hidden state h

What does “number of expected features” mean? Since the correlation is along the seq_len axis, should I set input_size to seq_len and permute the input? Thanks.

Hi,

input_size, the “number of expected features”, is the dimensionality of each observation, i.e. how many values the RNN receives per timestep; in this case, 2.
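To see that input_size refers to the last axis of the input rather than the time axis, here is a minimal sketch with made-up sizes: a layer built with input_size=2 accepts the [batch_size, seq_len, 2] data, while one built with input_size=seq_len is rejected. The correlation along seq_len is captured by the recurrence over timesteps, not by input_size.

```python
import torch
import torch.nn as nn

batch_size, seq_len = 4, 10  # illustrative sizes
x = torch.randn(batch_size, seq_len, 2)  # 2 values per timestep

ok = nn.RNN(input_size=2, hidden_size=100, batch_first=True)
out, _ = ok(x)  # works: x.size(-1) == input_size

# input_size=seq_len is rejected, because nn.RNN compares input_size
# with the LAST dimension of the input, not with the time axis.
wrong = nn.RNN(input_size=seq_len, hidden_size=100, batch_first=True)
try:
    wrong(x)
    raised = False
except RuntimeError as err:
    raised = True
    print(err)
```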

Also, by default (batch_first=False), your input to nn.RNN should have shape [seq_len, batch_size, input_size]. At every timestep t, the RNN receives the slice input[t, :, :], a matrix that contains the observations at timestep t from all sequences in the batch.
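The sketch below, assuming arbitrary sizes and a hidden_size of 8 chosen only for illustration, loads the same weights into a default (time-major) RNN and a batch_first=True RNN to show that the two layouts compute the same thing; only the order of the first two axes changes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, input_size = 5, 3, 2

x_bf = torch.randn(batch_size, seq_len, input_size)  # batch-first layout

rnn_tm = nn.RNN(input_size=input_size, hidden_size=8)  # default: batch_first=False
rnn_bf = nn.RNN(input_size=input_size, hidden_size=8, batch_first=True)
rnn_bf.load_state_dict(rnn_tm.state_dict())  # same weights, only the layout differs

x_tm = x_bf.transpose(0, 1)  # [seq_len, batch_size, input_size]
out_tm, _ = rnn_tm(x_tm)
out_bf, _ = rnn_bf(x_bf)

# Same computation; the outputs differ only in which axis comes first
print(torch.allclose(out_tm.transpose(0, 1), out_bf))  # True
# x_tm[t] is the [batch_size, input_size] matrix consumed at timestep t
print(x_tm[0].shape)  # torch.Size([5, 2])
```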

Hi,

I have set batch_first=True as one of the parameters, so the batch dimension should come first, right?

Oh, I didn’t see that argument. Yes, that’s correct: with batch_first=True, the input shape is [batch_size, seq_len, input_size].

Can you please elaborate on this?

Let’s say batch_size=5 and seq_len=3. Then each batch looks like:

```
batch = [
    [x1_1, x1_2, x1_3],
    [x2_1, x2_2, x2_3],
    ...
    [x5_1, x5_2, x5_3]
]  # shape (batch_size, seq_len, input_size)
```

where x{seq_id}_{timestep} is the observation (a vector of length input_size) from sequence seq_id at that timestep.
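The toy batch can be built directly as a tensor. In this hypothetical sketch, each observation x{s}_{t} is simply the pair [s, t] (so input_size=2, as in the original question), which makes it easy to check where every element ends up.

```python
import torch

batch_size, seq_len, input_size = 5, 3, 2

# Hypothetical toy data: x{s}_{t} = [s, t], recording which sequence (1..5)
# and which timestep (1..3) each observation came from.
batch = torch.tensor(
    [[[s, t] for t in range(1, seq_len + 1)] for s in range(1, batch_size + 1)],
    dtype=torch.float32,
)
print(batch.shape)  # torch.Size([5, 3, 2]) -> (batch_size, seq_len, input_size)

# x2_3 (sequence 2, timestep 3) lives at batch[2 - 1, 3 - 1]
print(batch[1, 2])  # tensor([2., 3.])
```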

For a given timestep t, the RNN reads the t-th observation from every sequence in the batch. So at timestep t, it reads and processes input_t:

```
input_1 = [x1_1, x2_1, ..., x5_1]  # shape (batch_size, input_size)
input_2 = [x1_2, x2_2, ..., x5_2]
input_3 = [x1_3, x2_3, ..., x5_3]
```
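Slicing one timestep out of the batch gives input_t, and feeding those slices one at a time while carrying the hidden state forward should reproduce a single full-sequence call. This is a sketch with random stand-in data and an arbitrary hidden_size=4, using a unidirectional nn.RNN (a bidirectional one cannot be unrolled this way, because its backward pass needs the future timesteps).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, input_size, hidden_size = 5, 3, 2, 4
batch = torch.randn(batch_size, seq_len, input_size)  # stand-in data

# input_t is one timestep of the batch: shape (batch_size, input_size)
input_1 = batch[:, 0, :]
print(input_1.shape)  # torch.Size([5, 2])

rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)
full_out, _ = rnn(batch)  # whole sequence in one call

# Manual unrolling: one timestep per call, carrying the hidden state forward
h = None
steps = []
for t in range(seq_len):
    step_in = batch[:, t : t + 1, :]  # length-1 time axis: (batch_size, 1, input_size)
    step_out, h = rnn(step_in, h)
    steps.append(step_out)
manual_out = torch.cat(steps, dim=1)

print(torch.allclose(full_out, manual_out, atol=1e-6))  # True
```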

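However, you don’t need to split your data into per-timestep inputs like this yourself; the RNN just needs

```
rnn_input = [input_1, input_2, input_3]  # shape (seq_len, batch_size, input_size)
```

from you. That is the default layout; with batch_first=True, you pass the (batch_size, seq_len, input_size) batch above directly.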
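Assembling rnn_input from the per-timestep slices amounts to stacking them along a new leading axis, which is the same as swapping the batch and time axes of the batch. A sketch with random stand-in data:

```python
import torch

batch_size, seq_len, input_size = 5, 3, 2
batch = torch.randn(batch_size, seq_len, input_size)  # (batch_size, seq_len, input_size)

# Collect input_1 ... input_3 and stack them along a new leading time axis
inputs = [batch[:, t, :] for t in range(seq_len)]
rnn_input = torch.stack(inputs)  # (seq_len, batch_size, input_size)
print(rnn_input.shape)  # torch.Size([3, 5, 2])

# This is just the batch with its first two axes swapped, so in practice
# batch.transpose(0, 1) (or batch_first=True) does the job without a loop.
print(torch.equal(rnn_input, batch.transpose(0, 1)))  # True
```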

Hope this helps!


I have one last question: in the above example, what would be the shape of the data? Thank you so much for your effort.

I’ve edited my answer; hopefully it’s clearer now.

Perfect, thanks. This helped.