Last hidden state in bidirectional stacked GRU

Hello. I am building a BiGRU for classification. I decided to use max pooling and average pooling in my model and to concatenate both with the last hidden state. Could you please explain the recommended approach for handling the last hidden state of stacked bidirectional models?

Layers that I use:

        self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.embedding_dim, self.hidden_size, num_layers=self.n_layers, 
                          dropout=(0 if self.n_layers == 1 else self.dropout_p), batch_first=True,
                          bidirectional=self.bidirectional)
        # Linear layer input size: max_pool and avg_pool each contribute
        # hidden_size*n_directions features, and hidden[-1] contributes hidden_size
        self.linear = nn.Linear(self.hidden_size * self.n_directions * 2 + self.hidden_size,
                                self.output_size)
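For reference, the arithmetic behind that `in_features` value can be sketched with a small helper. `linear_in_features` and its `last_state` option are hypothetical names used only for illustration; the sketch assumes the last state, max pool and average pool are concatenated along dim=1 as in the forward function.

```python
def linear_in_features(hidden_size: int, n_directions: int, last_state: str) -> int:
    """Hypothetical helper: input size of the final linear layer for
    [last state, max_pool, avg_pool] concatenated along dim=1.

    last_state: "one"    -> a single (batch, hidden_size) slice such as hidden[-1]
                "concat" -> forward and backward final states concatenated
    """
    pooled = 2 * hidden_size * n_directions  # max_pool + avg_pool
    if last_state == "one":
        return pooled + hidden_size
    if last_state == "concat":
        return pooled + hidden_size * n_directions
    raise ValueError(f"unknown last_state: {last_state!r}")


print(linear_in_features(128, 2, "one"))     # 640
print(linear_in_features(128, 2, "concat"))  # 768
```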

The forward propagation function:

        self.batch_size = input_seq.size(0)
        
        # Embeddings shapes
        # Input: (batch_size,  seq_length)
        # Output: (batch_size, seq_length, embedding_dim)
        emb_out = self.embedding(input_seq)
        emb_out = self.dropout(emb_out)
        
        # Pack padded batch of sequences for RNN module
        packed_emb = nn.utils.rnn.pack_padded_sequence(emb_out, input_lengths, batch_first=True)
                
        # GRU input/output shapes, if batch_first=True
        # Input: (batch_size, seq_len, embedding_dim)
        # Output: (batch_size, seq_len, hidden_size*num_directions)
        # Number of directions = 2 when used bidirectional, otherwise 1
        # shape of hidden: (n_layers x num_directions, batch_size, hidden_size)
        # Hidden state defaults to zero if not provided
        gru_out, hidden = self.gru(packed_emb, hidden)
        # gru_out: tensor containing the output features h_t from the last layer of the GRU
        # gru_out comprises all the hidden states in the last layer ("last" depth-wise, not time-wise)
        # For a BiGRU, gru_out is the concatenation of the forward and backward GRU representations
        # hidden (h_n) comprises the hidden states after the last timestep
        
        # Pad a packed batch
        # Output: (batch_size, seq_len, hidden_size*num_directions)
        gru_out, _ = nn.utils.rnn.pad_packed_sequence(gru_out, batch_first=True)
        
        # Select the maximum value over each dimension of the hidden representation (max pooling)
        # Permute the input tensor to dimensions: (batch_size, hidden*num_directions, seq_len)
        # Output dimensions: (batch_size, hidden_size*num_directions)
        max_pool = F.adaptive_max_pool1d(gru_out.permute(0,2,1), (1,)).view(self.batch_size,-1)
        
        # Consider the average of the representations (mean pooling)
        # Sum along the time axis (dim=1) and divide by the true lengths
        # (cast to gru_out's dtype and device, so this also works on GPU)
        # Output shape: (batch_size, hidden_size*num_directions)
        avg_pool = torch.sum(gru_out, dim=1) / input_lengths.view(-1, 1).to(gru_out)

        # Concatenate max_pooling, avg_pooling and last hidden state tensors
        concat_out = torch.cat([hidden[-1], max_pool, avg_pool], dim=1)

        concat_out = self.dropout(concat_out)
        out = self.linear(concat_out)
        return F.log_softmax(out, dim=-1)
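One caveat with the max pooling above, offered as a hedged aside: `pad_packed_sequence` fills padded positions with zeros by default (`padding_value=0.0`), and GRU outputs lie in (-1, 1), so for a short sequence whose real outputs are all negative in some feature, the padded zeros win the max. Below is a minimal sketch of length-aware pooling, assuming `lengths` is an integer tensor of true lengths; `masked_max_mean` is a hypothetical helper, not a PyTorch function. The mean part matches what the code above already computes, since padded zeros add nothing to the sum.

```python
import torch


def masked_max_mean(gru_out: torch.Tensor, lengths: torch.Tensor):
    """Hypothetical helper: max and mean pooling over time that ignore padding.

    gru_out: (batch, seq_len, features), e.g. from pad_packed_sequence(batch_first=True)
    lengths: (batch,) integer tensor of true sequence lengths
    """
    seq_len = gru_out.size(1)
    # mask[b, t] is True for valid (non-padded) time steps
    mask = torch.arange(seq_len, device=gru_out.device)[None, :] < lengths.to(gru_out.device)[:, None]
    mask = mask.unsqueeze(-1)  # (batch, seq_len, 1), broadcasts over features

    # Padded positions become -inf so they can never win the max.
    max_pool = gru_out.masked_fill(~mask, float("-inf")).max(dim=1).values
    # Zero out padded positions, then divide by the true length.
    mean_pool = (gru_out * mask).sum(dim=1) / lengths.to(gru_out)[:, None]
    return max_pool, mean_pool


# All real outputs are negative; the second sequence (length 2) has one zero-padded step.
out = torch.full((2, 3, 1), -0.5)
out[1, 2] = 0.0
lengths = torch.tensor([3, 2])

plain_max = out.max(dim=1).values        # padding wins for the second sequence
mx, mn = masked_max_mean(out, lengths)
print(plain_max.squeeze(-1))  # tensor([-0.5000,  0.0000])
print(mx.squeeze(-1))         # tensor([-0.5000, -0.5000])
```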

What is the best way to use the hidden representations (gru_out) and the hidden state of the GRU?

In the preceding implementation, I pass gru_out, with shape (batch_size, seq_len, hidden_size * num_directions), to F.adaptive_max_pool1d, and after the view() the result has shape (batch_size, hidden_size * num_directions). I am wondering whether it is better to first sum the forward and backward halves of gru_out:

gru_out = (gru_out[:, :, :self.hidden_size] + gru_out[:, :, self.hidden_size:])

Doing this, I could reduce the number of input features of self.linear.
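A quick shape check of that slicing, assuming (as the nn.GRU docs describe the output layout) that the forward direction occupies the first hidden_size features and the backward direction the last hidden_size. A random tensor stands in for the real padded gru_out here, and all sizes are arbitrary.

```python
import torch

batch_size, seq_len, hidden_size = 3, 5, 4
# Stand-in for a padded bidirectional output: (batch, seq_len, 2 * hidden_size)
gru_out = torch.randn(batch_size, seq_len, 2 * hidden_size)

# Forward features first, backward features second.
summed = gru_out[:, :, :hidden_size] + gru_out[:, :, hidden_size:]

# Equivalent formulation: split the feature axis into (num_directions, hidden_size).
summed_view = gru_out.view(batch_size, seq_len, 2, hidden_size).sum(dim=2)

print(summed.shape)                         # torch.Size([3, 5, 4])
print(torch.allclose(summed, summed_view))  # True
```

After summing, each pooled vector has hidden_size features instead of 2 * hidden_size, which is what shrinks the input to self.linear.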

What about when the number of layers is 2 or more? Should I sum gru_out as above and then sum hidden along its n_layers x num_directions dimension, or simply use hidden[-1]? I guess the latter discards the rest of the first dimension, so some information would be lost.

I am not sure at which step I should sum or concatenate, or what exactly, to do this properly.

Any advice would be appreciated.


If you’re interested in the last hidden state, i.e., the hidden state after the last time step, I wouldn’t bother with gru_out and would simply use hidden (in terms of your example). According to the docs:

  • hidden (h_n in the docs) has shape (num_layers*num_directions, batch, hidden_size), regardless of batch_first
  • layers and directions can be separated using h_n.view(num_layers, num_directions, batch, hidden_size)
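A minimal sketch that checks this layout on a toy 2-layer BiGRU with unpadded input (all sizes are arbitrary assumptions). It confirms that the last layer's forward final state matches the forward half of the output at the last time step, that the backward final state matches the backward half at t = 0, and that hidden[-1] on its own is only the backward state.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_layers, num_directions, batch, hidden_size = 2, 2, 3, 4
seq_len, input_size = 5, 6

gru = nn.GRU(input_size, hidden_size, num_layers=num_layers,
             batch_first=True, bidirectional=True)
x = torch.randn(batch, seq_len, input_size)  # no padding in this sketch
output, h_n = gru(x)

print(output.shape)  # torch.Size([3, 5, 8]) -> (batch, seq_len, 2 * hidden_size)
print(h_n.shape)     # torch.Size([4, 3, 4]) -> (num_layers * 2, batch, hidden_size)

h = h_n.view(num_layers, num_directions, batch, hidden_size)
last_fwd, last_bwd = h[-1, 0], h[-1, 1]

# Forward direction ends at the last time step; backward direction ends at t = 0.
print(torch.allclose(last_fwd, output[:, -1, :hidden_size]))  # True
print(torch.allclose(last_bwd, output[:, 0, hidden_size:]))   # True
# hidden[-1] without the view is just the last layer's backward state.
print(torch.equal(h_n[-1], last_bwd))  # True
```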

So you shouldn’t simply take hidden[-1] (with bidirectional=True, that is only the backward direction of the last layer); first call view() to separate num_layers and num_directions (1 or 2). If you do

hidden = hidden.view(num_layers, 2, batch, hidden_size) # 2 for bidirectional
last_hidden = hidden[-1]

then last_hidden.shape = (2, batch, hidden_size) and you can do

last_hidden_fwd = last_hidden[0]
last_hidden_bwd = last_hidden[1]

Whether you sum or concatenate them is up to you. You can also have a look at this related post.
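With packed, variable-length batches the same view() works, and the forward final state is taken at each sequence's own last valid step rather than at the padded end. A hedged sketch with toy sizes (lengths sorted in descending order, since pack_padded_sequence defaults to enforce_sorted=True), ending with both combination options:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
num_layers, batch, hidden_size, input_size, max_len = 2, 3, 4, 6, 5
lengths = torch.tensor([5, 3, 2])  # sorted descending

gru = nn.GRU(input_size, hidden_size, num_layers=num_layers,
             batch_first=True, bidirectional=True)
x = torch.randn(batch, max_len, input_size)
packed = pack_padded_sequence(x, lengths, batch_first=True)
packed_out, h_n = gru(packed)
output, _ = pad_packed_sequence(packed_out, batch_first=True)

last = h_n.view(num_layers, 2, batch, hidden_size)[-1]  # (2, batch, hidden_size)
last_fwd, last_bwd = last[0], last[1]

# Forward state: each sequence's own last valid step. Backward state: t = 0.
idx = torch.arange(batch)
print(torch.allclose(last_fwd, output[idx, lengths - 1, :hidden_size]))  # True
print(torch.allclose(last_bwd, output[:, 0, hidden_size:]))              # True

concat = torch.cat([last_fwd, last_bwd], dim=1)  # (batch, 2 * hidden_size)
summed = last_fwd + last_bwd                     # (batch, hidden_size)
print(concat.shape, summed.shape)  # torch.Size([3, 8]) torch.Size([3, 4])
```

Concatenation keeps the two directions distinguishable for the linear layer at the cost of hidden_size extra inputs; summing keeps the input size at hidden_size.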
