If you look at the docs, the shape of h_n is (num_layers*num_directions, batch_size, hidden_dim). This means that
c_n_merged = c_n.reshape(self.batch_size, -1)
will mess up your data (see also here). Also, you want to use h_n, not c_n, for further processing.
# Separate num_layers and num_directions
h_n = h_n.view(num_layers, num_directions, batch_size, hidden_dim)
# Get last hidden state w.r.t. number of layers
h_last = h_n[-1]
# Handle both direction be concatenating the 2 respective last hidden states
h_last = torch.cat((h_last[0], h_last_[1]), 1)
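To make the shapes concrete, here is a minimal self-contained sketch with toy dimensions (all sizes are hypothetical, just for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical toy sizes
num_layers, batch_size, hidden_dim, seq_len, input_dim = 2, 4, 8, 5, 3
num_directions = 2  # bidirectional

lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers,
               bidirectional=True, batch_first=True)
x = torch.randn(batch_size, seq_len, input_dim)
_, (h_n, c_n) = lstm(x)

# h_n comes out as (num_layers*num_directions, batch_size, hidden_dim)
h_n = h_n.view(num_layers, num_directions, batch_size, hidden_dim)
h_last = h_n[-1]                               # last layer: (num_directions, batch_size, hidden_dim)
h_last = torch.cat((h_last[0], h_last[1]), 1)  # (batch_size, 2*hidden_dim)
print(h_last.shape)  # torch.Size([4, 16])
```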
Now h_last should have a shape of (batch_size, 2*hidden_dim). This means that you then also need to change the definition of the first linear layer to:
self.linear_layer_1 = nn.Linear(2*self.hidden_dim, self.in_4_dim)
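A quick standalone check that a layer defined this way accepts the concatenated hidden state (sizes and the name in_4_dim are just placeholders mirroring your code):

```python
import torch
import torch.nn as nn

hidden_dim, in_4_dim, batch_size = 8, 32, 4  # hypothetical sizes
linear_layer_1 = nn.Linear(2 * hidden_dim, in_4_dim)

# Stand-in for the concatenated last hidden state of both directions
h_last = torch.randn(batch_size, 2 * hidden_dim)
out = linear_layer_1(h_last)
print(out.shape)  # torch.Size([4, 32])
```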
I’m actually surprised that your code doesn’t throw an error, but I might very well have missed something.