Example of using Normalization with LSTM

import torch
import torch.nn as nn


class LSTMExample(nn.Module):

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(LSTMExample, self).__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # input arrives as (seq_len, batch_size), so batch_first stays False
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=2)
        self.linear = nn.Linear(self.hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        seq_len, batch_size = input.size()
        if hidden is None:
            # fresh zero states, one per LSTM layer:
            # shape (num_layers, batch_size, hidden_dim)
            h_0 = torch.zeros(2, batch_size, self.hidden_dim, device=input.device)
            c_0 = torch.zeros(2, batch_size, self.hidden_dim, device=input.device)
        else:
            h_0, c_0 = hidden

        embeds = self.embeddings(input)  # (seq_len, batch_size, embedding_dim)
        output, hidden = self.lstm(embeds, (h_0, c_0))
        # flatten the time and batch dims before the vocabulary projection
        output = self.linear(output.view(seq_len * batch_size, -1))

        return output, hidden
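
For reference, this is roughly how I am calling it (the sizes here are made up for illustration):

model = LSTMExample(vocab_size=1000, embedding_dim=128, hidden_dim=256)
tokens = torch.randint(0, 1000, (35, 4))  # (seq_len, batch_size) of token ids
logits, hidden = model(tokens)
print(logits.shape)  # torch.Size([140, 1000]), i.e. (seq_len * batch_size, vocab_size)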

Hi, I am a newcomer to LSTMs.

I want to change the model to make it perform better, and as far as I know, normalization is a good way to do that.

But I wonder whether the normalization modules in PyTorch (BatchNorm1d, BatchNorm2d, BatchNorm3d, GroupNorm, InstanceNorm1d, InstanceNorm2d, InstanceNorm3d, LayerNorm, LocalResponseNorm) are suitable for an LSTM, because some people say plain BatchNorm does not work well in RNNs.
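
For example, would applying nn.LayerNorm over the hidden dimension of the LSTM output be a reasonable way to do it? Here is a minimal sketch of what I have in mind (the class name LSTMWithNorm and the placement of the norm are just my guesses, not something I found in the docs):

class LSTMWithNorm(nn.Module):

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(LSTMWithNorm, self).__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=2)
        # LayerNorm normalizes over the last dim (hidden_dim), per time step
        # and per sample, so unlike BatchNorm it does not mix statistics
        # across the batch or across sequence positions
        self.norm = nn.LayerNorm(self.hidden_dim)
        self.linear = nn.Linear(self.hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        seq_len, batch_size = input.size()
        if hidden is None:
            h_0 = torch.zeros(2, batch_size, self.hidden_dim, device=input.device)
            c_0 = torch.zeros(2, batch_size, self.hidden_dim, device=input.device)
        else:
            h_0, c_0 = hidden

        embeds = self.embeddings(input)
        output, hidden = self.lstm(embeds, (h_0, c_0))
        output = self.norm(output)  # still (seq_len, batch_size, hidden_dim)
        output = self.linear(output.view(seq_len * batch_size, -1))

        return output, hidden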

Thanks so much!