I would like to ask how the hidden states produced by a Bidirectional RNN are concatenated.
If I’m not mistaken, the output of a PyTorch RNN has shape (N, T, 2*H), given that the ‘batch_first’ and ‘bidirectional’ parameters have been set to True [N: number of examples, T: number of time steps, H: cell size].
Are the two 3-D tensors concatenated on the last axis? Or is each pair of tensors computed on a given time step “put” next to each other?
I would like to merge the hidden states of the two directions of the Bidirectional RNN (to produce an output of shape (N, T, H)). If I’m correct, there is a difference between the torch.split() and torch.chunk() functions.
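For splitting the last axis into two equal halves, the two functions should give the same result; they only differ in how you specify the split. Below is a minimal sketch, assuming a random tensor stands in for the RNN output (the sizes are arbitrary): torch.split takes the size of each piece, while torch.chunk takes the number of pieces.

```python
import torch

# Arbitrary sizes for illustration: N examples, T time steps, H hidden units.
N, T, H = 4, 5, 3
output = torch.randn(N, T, 2 * H)  # stand-in for a bidirectional RNN output

# torch.split: second argument is the size of each piece.
fwd_s, bwd_s = torch.split(output, H, dim=2)
# torch.chunk: second argument is the number of pieces.
fwd_c, bwd_c = torch.chunk(output, 2, dim=2)

# For an even split into two halves, both return identical tensors.
assert torch.equal(fwd_s, fwd_c) and torch.equal(bwd_s, bwd_c)

merged = fwd_s + bwd_s  # (N, T, H)
print(merged.shape)  # torch.Size([4, 5, 3])

# They differ when the size does not divide evenly:
print([t.numel() for t in torch.split(torch.arange(5), 2)])  # [2, 2, 1]
print([t.numel() for t in torch.chunk(torch.arange(5), 2)])  # [3, 2]
```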
Thank you for your help in advance!
You’re not mistaken – when used bidirectionally, the output of all the PyTorch recurrent modules (assuming batch_first=True) has shape:
(num_examples, seq_len, 2 * hidden_dim)
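A quick way to check this shape yourself is sketched below, assuming a random input and arbitrary sizes; it runs nn.RNN, nn.LSTM and nn.GRU with batch_first=True and bidirectional=True and records the shape of the first return value.

```python
import torch
import torch.nn as nn

N, T, D, H = 2, 7, 10, 16  # batch, seq_len, input features, hidden size (arbitrary)
x = torch.randn(N, T, D)

shapes = {}
for cls in (nn.RNN, nn.LSTM, nn.GRU):
    rnn = cls(input_size=D, hidden_size=H, batch_first=True, bidirectional=True)
    output, _ = rnn(x)  # second value is h_n, or (h_n, c_n) for the LSTM
    shapes[cls.__name__] = tuple(output.shape)

print(shapes)  # each entry should be (2, 7, 32), i.e. (N, T, 2 * H)
```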
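The two 3D tensors are actually concatenated on the last axis, with the forward direction in the first hidden_dim features and the backward direction in the last hidden_dim features, so to merge them, we usually do something like this: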
output = output[:, :, :self.hidden_dim] + output[:, :, self.hidden_dim:]
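Here is a standalone version of that merge, written as a sketch with arbitrary sizes and a GRU chosen only for illustration. It also checks the layout against h_n: the forward direction finishes at the last time step and the backward direction finishes at the first one, so those slices of output should match the two final hidden states.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, T, D, H = 3, 6, 4, 5  # arbitrary sizes for illustration
rnn = nn.GRU(input_size=D, hidden_size=H, batch_first=True, bidirectional=True)
output, h_n = rnn(torch.randn(N, T, D))  # output: (N, T, 2H), h_n: (2, N, H)

# First H features: forward direction; last H features: backward direction.
fwd, bwd = output[:, :, :H], output[:, :, H:]

# Layout check: forward ends at t = T - 1, backward ends at t = 0.
assert torch.allclose(fwd[:, -1], h_n[0])
assert torch.allclose(bwd[:, 0], h_n[1])

# Merge by summing the two directions -> (N, T, H).
merged = fwd + bwd
# Equivalent: expose a direction axis of size 2 and sum over it.
merged_alt = output.reshape(N, T, 2, H).sum(dim=2)
assert torch.allclose(merged, merged_alt)
print(merged.shape)  # torch.Size([3, 6, 5])
```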
You might also try averaging them (by dividing the merged hidden state by 2).
Alternatively, you can just use the concatenated hidden state as is, which is quite common, I think (I assume you’re using the hidden state as a context vector to condition a decoder?).
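If "context vector" means the final hidden state of the encoder, one way to build the concatenated version is sketched below. This assumes a multi-layer bidirectional LSTM with arbitrary sizes; h_n has shape (num_layers * 2, N, H), and its last two entries are the top layer's forward and backward final states.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, T, D, H, L = 3, 6, 4, 5, 2  # batch, seq_len, features, hidden size, layers (arbitrary)
encoder = nn.LSTM(D, H, num_layers=L, batch_first=True, bidirectional=True)
output, (h_n, c_n) = encoder(torch.randn(N, T, D))  # h_n: (L * 2, N, H)

# Top layer: h_n[-2] is the forward final state, h_n[-1] the backward one.
context = torch.cat([h_n[-2], h_n[-1]], dim=1)  # (N, 2H)

# These match the corresponding slices of the concatenated output.
assert torch.allclose(h_n[-2], output[:, -1, :H])
assert torch.allclose(h_n[-1], output[:, 0, H:])
print(context.shape)  # torch.Size([3, 10])
```

With padded, packed sequences the output slices above no longer line up with h_n for the shorter sequences, so h_n is usually the safer source for the final state.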
This discussion is handy, and helped me out when I couldn’t quite figure out the documentation on PyTorch’s bidirectional RNNs.
Lastly, if you want a fixed-length summary of the hidden states, you can apply L2 pooling over the time axis. I believe this is the component-wise root-mean-square of all the hidden states: square each hidden state component-wise at every time step, average them (i.e., sum them and divide by the sequence length), and take the component-wise square root of the result.
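The root-mean-square pooling described above, pooled_j = sqrt((1/T) * sum_t h_{t,j}^2), can be sketched in a few lines. The helper name l2_pool is hypothetical, and the input is a tiny hand-made tensor of shape (N, T, F) rather than real RNN output.

```python
import torch

def l2_pool(output: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: component-wise RMS over the time axis, (N, T, F) -> (N, F)."""
    return output.pow(2).mean(dim=1).sqrt()

output = torch.tensor([[[3.0, 0.0], [4.0, 2.0]]])  # (N=1, T=2, F=2)
pooled = l2_pool(output)
print(pooled)  # approximately [[3.5355, 1.4142]], i.e. [sqrt(12.5), sqrt(2)]
```

This averages over every time step, so with padded batches the padding positions would need to be masked out (and the divisor set to each sequence's true length) before taking the mean.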