Input shape to GRU layer

Greetings. I am confused about the input shape for a GRU layer.
I have one batch of 128 images, and I extracted 9 features from each image.
So now my shape is (1,128,9).
Here 1 is the batch size, there are 128 images in each batch, and each image has 9 features. The images are in sequence, for example 128 frames of a video. So simply, one batch represents one video.

This is the GRU layer


Question 1: Is the input_size=128 correctly defined?
Here is the code of forward function

def forward(features):
    features = features.permute(0, 2, 1)  # [1, 9, 128]

Question 2: Is the code in the forward function correctly defined?

Question 3: In shape (1,128,9), 9 is the sequence length and 128 is the input size. Is this correct, or is it the other way around?


You also need a hidden state of size (num_layers, batch_size, hidden_size), which can start out as zeros and evolve at each time step.

See here for an example:

I saw you updated your question. This is the example given on the docs page:

rnn = nn.GRU(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)

Note that you’re still missing the hidden state h0 in your code.

The GRU layer expects two inputs. The hidden input is how you convey contextual information to the model between timesteps.

Now onto your questions.

Q1. No, it is incorrect to have the batch size at dim=1 when you have clearly set batch_first=True. You have to choose how you want to send your data in and get it out, then set that argument accordingly. If you set batch_first=True, your input shape should be (128, 1, 9).

Q2. No. You’re not giving the model a hidden state.

Q3. The order can be N, L, H with batch_first=True or L, N, H if False where L is length, N is batch size and H is number of features.
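To make the two layouts concrete, here is a minimal sketch. The sizes and variable names are just assumptions for illustration (N=1, L=128, H=9, and an arbitrary hidden size of 16):

```python
import torch

# Hypothetical sizes for illustration: N = batch, L = sequence length, H = features
N, L, H = 1, 128, 9

gru_bf = torch.nn.GRU(input_size=H, hidden_size=16, batch_first=True)
gru_sf = torch.nn.GRU(input_size=H, hidden_size=16, batch_first=False)

x_bf = torch.randn(N, L, H)                # (N, L, H) for batch_first=True
x_sf = x_bf.transpose(0, 1).contiguous()   # (L, N, H) for batch_first=False

out_bf, h_bf = gru_bf(x_bf)
out_sf, h_sf = gru_sf(x_sf)

print(out_bf.shape)  # torch.Size([1, 128, 16])
print(out_sf.shape)  # torch.Size([128, 1, 16])
# The returned hidden state is (num_layers, N, hidden_size) in both cases;
# batch_first only changes the layout of the input and output tensors.
print(h_bf.shape, h_sf.shape)
```

Either way, the last dimension is the number of features, and that is what input_size has to match.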

Last issue, when you instantiated the GRU class, you put the input size as 128, which you have indicated to be your batch size. That is incorrect. Input size should equal the number of features, in this case, 9.

input_size refers to the number of features, in your case 9. You don’t specify the sequence length when defining an nn.GRU or nn.LSTM. Recurrent Neural Networks go through each item in your sequence no matter how long.

hidden_size is the size of the internal state of your nn.GRU. You probably want to set it to a larger value than 9 :), as it might have to capture a lot of information.

So you should define your recurrent layer as follows:

gru=torch.nn.GRU(input_size=9, hidden_size=256, batch_first=True) # The 256 is just a suggestion

Since you set batch_first=True, your GRU layer expects an input of (batch_size, seq_len, input_size), which in your example would be (1, 128, 9). As this already seems to be your shape, there should be no need to permute the dimensions.
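As a quick sanity check, here is a minimal sketch of that setup. It assumes random data in place of your real features, and the variable names are only for illustration:

```python
import torch

gru = torch.nn.GRU(input_size=9, hidden_size=256, batch_first=True)

# Stand-in for one video: 1 batch, 128 frames, 9 features per frame
features = torch.randn(1, 128, 9)

# No permute needed: the tensor is already (batch_size, seq_len, input_size)
output, h_n = gru(features)

print(output.shape)  # torch.Size([1, 128, 256]), one hidden vector per frame
print(h_n.shape)     # torch.Size([1, 1, 256]), final hidden state

# For a single-layer, unidirectional GRU the last output step is the final hidden state
print(torch.allclose(output[:, -1, :], h_n[-1]))  # True
```

If you only need one vector per video, h_n[-1] (or equivalently output[:, -1, :]) is a common choice to pass on to the next layer.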

@J_Johnson You don’t have to explicitly give an initial hidden state as input – although it’s probably a good practice to always do so.

If you don’t specify an initial hidden state, it gets initialized with a 0-vector. You can check the source code. The signature for the forward() method is

def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:

So hx is optional. If it is None, then it gets initialized directly in the forward() method as follows: hx = torch.zeros(...)
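You can also check this empirically. A small sketch, assuming the same layer sizes as above and random input: omitting hx should give the same result as passing an explicit zero tensor.

```python
import torch

torch.manual_seed(0)
gru = torch.nn.GRU(input_size=9, hidden_size=256, batch_first=True)
x = torch.randn(1, 128, 9)

# Explicit zero initial state: (num_layers, batch_size, hidden_size)
h0 = torch.zeros(1, 1, 256)

out_default, hn_default = gru(x)        # hx omitted, so it defaults to None
out_explicit, hn_explicit = gru(x, h0)  # hx passed in as zeros

print(torch.allclose(out_default, out_explicit))  # True
print(torch.allclose(hn_default, hn_explicit))    # True
```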

I just assume someone making use of a GRU with a sequence length of 1 intends to pass in later timesteps via their model’s forward method. So while that might not give any errors, it sure would negate any benefit of using an RNN class layer.
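To illustrate what I mean, here is a rough sketch (the sizes are assumptions, and the loop is only for demonstration). Feeding one timestep at a time only matches the full-sequence result if the hidden state is carried forward between calls:

```python
import torch

torch.manual_seed(0)
gru = torch.nn.GRU(input_size=9, hidden_size=32, batch_first=True)
x = torch.randn(1, 128, 9)  # (batch_size, seq_len, input_size)

with torch.no_grad():
    # Whole sequence in one call
    out_full, hn_full = gru(x)

    # One timestep per call, passing the hidden state along
    h = None  # None means "start from zeros"
    steps = []
    for t in range(x.size(1)):
        out_t, h = gru(x[:, t:t + 1, :], h)
        steps.append(out_t)
    out_steps = torch.cat(steps, dim=1)

    # One timestep per call WITHOUT carrying the hidden state
    out_no_state = torch.cat(
        [gru(x[:, t:t + 1, :])[0] for t in range(x.size(1))], dim=1
    )

print(torch.allclose(out_full, out_steps, atol=1e-5))
print(torch.allclose(out_full, out_no_state, atol=1e-5))
```

The first comparison should hold up to floating-point tolerance. The second generally will not, because every step restarts from a zero state and the layer never sees any context.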

I thought 1 is the batch_size and 128 is the seq_len. But, yeah, now reading “I have a batch of 128 images” again, I’m not so sure anymore :)


Hi, thanks for the detailed answer. Sorry, my bad. Actually my batch size is 1, each batch has 128 images, and each image has 9 features.
Second thing: if we don't define the hidden state, won't it be set automatically, since it did not throw any error? Thanks.

Apologies from my side. My batch size is 1, each batch has 128 images, and each image has 9 features. For example, if I had a batch size of 2, then the shape would be (2,128,9).