Batch normalization and weight initialization in Seq2Seq

Hello PyTorch!
I want to apply Batch Normalization and weight initialization in a Seq2Seq problem,
but I don't know how to apply them in my model.

Please help me guys!
My model is shown below.


class Encoder(nn.Module):
    """Thin wrapper around a single ``nn.LSTM`` used as the sequence encoder.

    Simply forwards the input through the LSTM and hands back both the
    per-step outputs and the final (h, c) state.
    """

    def __init__(self,
                 input_dim,
                 hidden_dim,
                 num_layers=1,
                 dropout=0,
                 bidirectional=False):
        super(Encoder, self).__init__()
        # The only submodule: an LSTM mapping input_dim -> hidden_dim.
        self.encoder = nn.LSTM(input_dim,
                               hidden_dim,
                               num_layers=num_layers,
                               dropout=dropout,
                               bidirectional=bidirectional)

    def forward(self, x, hidden):
        """Run the LSTM; returns (outputs, (h_n, c_n)).

        `hidden` may be None, in which case the LSTM zero-initializes
        its state.
        """
        out, state = self.encoder(x, hidden)
        return out, state

class Decoder(nn.Module):
    """LSTM decoder followed by a Linear projection to the output size.

    Returns the projected outputs, the raw LSTM outputs (used by the
    caller as the next decoder input), and the next hidden state.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1, dropout=0, bidirectional=False):
        super(Decoder, self).__init__()
        self.decoder = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, dropout=dropout, bidirectional=bidirectional)
        # Projects each hidden vector to the target dimensionality.
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, hidden):
        """One decoder step over the sequence held in `x`.

        `hidden` may be None (LSTM zero-initializes its state).
        """
        decoder_output, next_hidden = self.decoder(x, hidden)

        # nn.Linear broadcasts over all leading dimensions, so a single
        # call replaces the original per-timestep Python loop
        # (stack of self.linear(decoder_output[:, i, :]) over dim 1)
        # with identical results and far less overhead.
        outputs = self.linear(decoder_output)
        return outputs.squeeze(), decoder_output, next_hidden

class Model(nn.Module):
    """Seq2Seq model: encode the input, then unroll the decoder
    `output_length` times, feeding each decoder output back in as the
    next input (no teacher forcing).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1, output_length=3):
        super(Model, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, num_layers=num_layers)
        # Decoder consumes encoder hidden vectors, hence input_dim == hidden_dim.
        self.decoder = Decoder(hidden_dim, hidden_dim, output_dim, num_layers=num_layers)
        self.output_length = output_length
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

    def forward(self, x):
        """Returns (decoded_sequence, last_encoder_output).

        Assumes sequence-first layout (default nn.LSTM, batch_first=False),
        so encoder_output[-1] is the final time step — TODO confirm against
        the caller's data layout.
        """
        # Fixed typo: was `encdoer_state` (the state is intentionally unused;
        # only the last output feeds the decoder).
        encoder_output, encoder_state = self.encoder(x, None)
        # Last encoder output, re-shaped to a length-1 sequence.
        decoder_input = torch.unsqueeze(encoder_output[-1], 0)

        seq = []
        next_hidden = None  # decoder starts from a zero state
        next_input = decoder_input

        for _ in range(self.output_length):
            # The decoder's raw LSTM output becomes the next step's input.
            output, next_input, next_hidden = self.decoder(next_input, next_hidden)
            seq.append(output)

        # Reuse decoder_input instead of recomputing the same unsqueeze.
        return torch.stack(seq, dim=0).squeeze(), decoder_input