Is this minimal LSTM classifier correctly set up?

Hi,

I’ve written a basic LSTM classifier and a few more complex models based on this basic code. The basic model should take in a batch of sequences, each of size [200 x 128] (200 time steps with 128 features per step), and assign one of 6 classes to each sequence. I’m fairly sure this is correct, and the model does learn on the dataset. However, looking at other LSTM implementations, I’ve seen that some include a few things that mine doesn’t, and I’m worried I’m missing something that could be affecting model training.

The basic model is as follows:

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self, num_classes=6, hidden_size=256, steps=200, bidirectional=False):
        super(Net, self).__init__()

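        self.bidirectional = bidirectional
        self.hidden_size = hidden_size
        # A bidirectional LSTM concatenates both directions, doubling the output width.
        num_directions = 2 if bidirectional else 1

        self.LSTM_one = nn.Sequential(
            nn.LSTM(input_size=128, hidden_size=hidden_size, num_layers=1, batch_first=True, bidirectional=bidirectional)
        )

        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            # Input is the flattened LSTM output: steps * num_directions * hidden_size.
            nn.Linear(in_features=steps * num_directions * hidden_size, out_features=500),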
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),

            nn.Linear(in_features=500, out_features=num_classes)
        )

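    def forward(self, x):
        # x: [batch, steps, 128]; no initial hidden state is passed to the LSTM here.
        x, hidden = self.LSTM_one(x)
        # Flatten every time step's output into one vector per sequence.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, hidden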
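
To sanity-check the shapes feeding the first Linear layer, here is a small sketch that runs a bare nn.LSTM with the same sizes as above and reports the width of the flattened output. The batch size of 4 and the helper name flattened_width are my own choices for illustration, not part of the model.

```python
import torch
import torch.nn as nn

# Sizes taken from the model above; the batch size of 4 is arbitrary.
batch, steps, features, hidden_size = 4, 200, 128, 256


def flattened_width(bidirectional):
    """Return the per-sample width of the flattened LSTM output (hypothetical helper)."""
    lstm = nn.LSTM(input_size=features, hidden_size=hidden_size,
                   num_layers=1, batch_first=True, bidirectional=bidirectional)
    x = torch.randn(batch, steps, features)
    # out has shape [batch, steps, num_directions * hidden_size].
    out, (h_n, c_n) = lstm(x)
    return torch.flatten(out, 1).shape[1]


uni_width = flattened_width(False)  # steps * hidden_size
bi_width = flattened_width(True)    # steps * 2 * hidden_size
print(uni_width, bi_width)
```

The unidirectional width equals steps*hidden_size, while the bidirectional width is twice that, so the in_features of the first Linear layer has to account for the number of directions.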

Elsewhere I’ve seen a separate method that initializes the hidden state of the LSTM, something like:

    def init_hidden(self):
        return (torch.zeros(1 + int(self.bidirectional), 1, self.hidden_size),
                torch.zeros(1 + int(self.bidirectional), 1, self.hidden_size))

and others that pass a hidden argument to the forward method, like this:

    def forward(self, inputs, hidden):
        output, hidden = self.lstm(inputs.view(1, 1, self.input_size), hidden)
        return output, hidden
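
One way to check whether an explicit initial state changes anything is to compare the LSTM output with and without it. As far as I understand the PyTorch documentation, nn.LSTM defaults h_0 and c_0 to zeros when no state is passed; the sketch below tests that assumption on a standalone layer. The batch size of 4 and the variable names are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

lstm = nn.LSTM(input_size=128, hidden_size=256, num_layers=1, batch_first=True)
x = torch.randn(4, 200, 128)  # [batch, steps, features]; batch size is arbitrary

# The state shape is [num_layers * num_directions, batch, hidden_size],
# and it stays in this order even when batch_first=True.
h0 = torch.zeros(1, x.size(0), 256)
c0 = torch.zeros(1, x.size(0), 256)

with torch.no_grad():
    out_default, _ = lstm(x)             # no initial state passed
    out_explicit, _ = lstm(x, (h0, c0))  # explicit zero initial state

same = torch.allclose(out_default, out_explicit)
print(same)
```

If the two outputs match, an init_hidden method that only returns zeros would be redundant for this model, and passing hidden into forward would matter only when state needs to carry over between calls.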

So my question is: are these two additional steps, which my LSTM model is missing, necessary, or have they been deprecated?

Many thanks.