LSTM - Did I correctly change my forward function and my architecture to work bidirectionally?

Since I found very little help on this in the documentation or elsewhere, I am just trying to figure out whether I did everything correctly. At first I only changed bidirectional to True, and to my surprise the model still ran without any error message. Now I have also changed the dimensions of the next linear layer and my forward method.

Here is my thought process:

  1. Since a bidirectional LSTM concatenates the outputs of both directions, I figured I need to double the input size of the next layer. That seemed odd because, as mentioned, the model still ran after I set bidirectional to True without touching the input size of the next layer. (The shape check after this list shows why.)

  2. I need to feed the final hidden states of both directions into the next layer, i.e. hn[-2] and hn[-1] for the top layer (hn[0] is the first layer's forward state, which only coincides with hn[-2] when n_layers is 1). So I then thought that maybe my model did not raise an error because I hadn't implemented this yet.
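
A quick shape check (a minimal sketch with made-up sizes, not from my actual data) explains point 1: the sequence output concatenates both directions along the feature dimension, but each slice of hn still has only hidden_size features, so a linear layer sized for n_hidden accepts hn[-1] without complaint.

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2,
               batch_first=True, bidirectional=True)
x = torch.randn(4, 10, 8)             # (batch, seq_len, input_size)
output, (hn, cn) = lstm(x)

print(output.shape)   # torch.Size([4, 10, 32]) -> 2 * hidden_size per time step
print(hn.shape)       # torch.Size([4, 4, 16])  -> (num_layers * 2, batch, hidden_size)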

It would just be nice to know for sure how to correctly adapt the LSTM structure. Any help is appreciated.

EDIT: I just realised that the PyTorch documentation states how hn is laid out for a bidirectional LSTM: the directions are stacked along the first dimension, layer by layer.
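
In practice that layout can be made explicit with a view (a sketch reusing the toy tensors from the shape check above):

# Separate layers and directions, as described in the docs:
hn_split = hn.view(2, 2, 4, 16)   # (num_layers, num_directions, batch, hidden_size)
forward_last = hn_split[-1, 0]    # top layer, forward direction  == hn[-2]
backward_last = hn_split[-1, 1]   # top layer, backward direction == hn[-1]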
My LSTM class now looks like this:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Model_LSTM(nn.Module):

    def __init__(self, n_features, n_classes, n_hidden, n_layers):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=n_features,
            hidden_size=n_hidden,
            num_layers=n_layers,
            batch_first=True,
            dropout=0,
            bidirectional=True          # bidirectional possible
        )

        # Both directions are concatenated, so the next layer
        # receives 2 * n_hidden features.
        self.dense1 = nn.Linear(n_hidden * 2, n_hidden)

        self.classifier = nn.Linear(n_hidden, n_classes)

        # Initialise every LSTM weight matrix, including the
        # *_reverse parameters that bidirectional=True adds.
        for name, param in self.lstm.named_parameters():
            if "weight" in name:
                nn.init.xavier_uniform_(param)

        nn.init.xavier_uniform_(self.dense1.weight)
        nn.init.xavier_uniform_(self.classifier.weight)

    def forward(self, x):
        # hn has shape (num_layers * 2, batch, n_hidden); the last two
        # slices are the top layer's final forward and backward states.
        # hn[0] would be the *first* layer's forward state, which only
        # matches the top layer when n_layers == 1.
        _, (hn, _) = self.lstm(x)
        out = torch.cat([hn[-2], hn[-1]], dim=1)   # (batch, 2 * n_hidden)
        out = F.relu(self.dense1(out))
        return self.classifier(out)
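
For completeness, a quick smoke test of the class above (hypothetical sizes, just to check the wiring):

model = Model_LSTM(n_features=8, n_classes=3, n_hidden=16, n_layers=2)
x = torch.randn(4, 10, 8)      # (batch, seq_len, n_features)
logits = model(x)
print(logits.shape)            # torch.Size([4, 3]) -> one score per class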