LayerNorm in LSTMCell

I found that nn.LayerNorm is available in PyTorch 0.4.0.

Can anyone please tell me how to use nn.LayerNorm in a multi-layer stack of LSTMCells?

Currently, I am using nn.LayerNorm like this:

self.lstm0 = nn.LSTMCell(input_size, hidden_size)
self.lstm1 = nn.LSTMCell(hidden_size, hidden_size)
self.lstm2 = nn.LSTMCell(hidden_size, hidden_size)
self.ln1 = nn.LayerNorm(hidden_size)
self.ln2 = nn.LayerNorm(hidden_size)

h_t0, c_t0 = self.lstm0(input, (h_t0, c_t0))
h_t1, c_t1 = self.lstm1(self.ln1(h_t0), (h_t1, c_t1))
h_t2, c_t2 = self.lstm2(self.ln2(h_t1), (h_t2, c_t2))  # third layer uses lstm2, not lstm1
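For reference, here is a minimal, self-contained sketch of the stacked approach above, wrapped in a module and unrolled over a sequence (with self.lstm2 used for the third layer). The class name StackedLSTMCellLN, the zero initial states, and the (seq_len, batch, input_size) input layout are assumptions for illustration. Note that in this arrangement nn.LayerNorm only normalizes the hidden state handed from one layer to the next; the gate computations inside each nn.LSTMCell are not normalized.

```python
import torch
import torch.nn as nn


class StackedLSTMCellLN(nn.Module):
    """Hypothetical wrapper: three LSTMCells with LayerNorm between layers."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm0 = nn.LSTMCell(input_size, hidden_size)
        self.lstm1 = nn.LSTMCell(hidden_size, hidden_size)
        self.lstm2 = nn.LSTMCell(hidden_size, hidden_size)
        self.ln1 = nn.LayerNorm(hidden_size)  # normalizes layer-0 output
        self.ln2 = nn.LayerNorm(hidden_size)  # normalizes layer-1 output

    def forward(self, inputs):
        # inputs: (seq_len, batch, input_size); states start at zero (assumption)
        batch = inputs.size(1)
        zeros = inputs.new_zeros(batch, self.hidden_size)
        h_t0, c_t0 = zeros, zeros
        h_t1, c_t1 = zeros, zeros
        h_t2, c_t2 = zeros, zeros
        outputs = []
        for x_t in inputs:  # iterate over the time dimension
            h_t0, c_t0 = self.lstm0(x_t, (h_t0, c_t0))
            h_t1, c_t1 = self.lstm1(self.ln1(h_t0), (h_t1, c_t1))
            h_t2, c_t2 = self.lstm2(self.ln2(h_t1), (h_t2, c_t2))
            outputs.append(h_t2)
        return torch.stack(outputs)  # (seq_len, batch, hidden_size)


model = StackedLSTMCellLN(input_size=10, hidden_size=20)
out = model(torch.randn(5, 3, 10))
print(out.shape)  # torch.Size([5, 3, 20])
```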


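nn.LSTMCell has no built-in option for layer normalization, so if you want to normalize inside the cell (not just between layers), you have to implement the cell yourself. Take a look here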
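As a rough illustration of what implementing it yourself can look like, here is a minimal sketch assuming the formulation from Ba et al. (2016), "Layer Normalization": the input and recurrent projections of the gates are layer-normalized separately, and the cell state is normalized before the output tanh. The class LayerNormLSTMCell and its attribute names are hypothetical, not part of torch.nn, and the (i, f, g, o) gate order is an arbitrary choice here. Because it keeps the nn.LSTMCell calling convention `cell(x, (h, c)) -> (h, c)`, it can stand in for the cells in the question's stack.

```python
import torch
import torch.nn as nn


class LayerNormLSTMCell(nn.Module):
    """Hypothetical LSTM cell with layer normalization inside the cell
    (sketch of the Ba et al., 2016 formulation; not a torch.nn API)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Projections for all four gates at once (no bias; added separately).
        self.weight_ih = nn.Linear(input_size, 4 * hidden_size, bias=False)
        self.weight_hh = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(4 * hidden_size))
        # Separate LayerNorms for the input and recurrent projections,
        # plus one for the cell state before the output tanh.
        self.ln_ih = nn.LayerNorm(4 * hidden_size)
        self.ln_hh = nn.LayerNorm(4 * hidden_size)
        self.ln_c = nn.LayerNorm(hidden_size)

    def forward(self, x, state):
        h, c = state
        gates = (self.ln_ih(self.weight_ih(x))
                 + self.ln_hh(self.weight_hh(h))
                 + self.bias)
        i, f, g, o = gates.chunk(4, dim=1)  # split into the four gates
        c_new = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h_new = torch.sigmoid(o) * torch.tanh(self.ln_c(c_new))
        return h_new, c_new


# Two stacked cells, mirroring the question's multi-layer setup.
cell0 = LayerNormLSTMCell(10, 20)
cell1 = LayerNormLSTMCell(20, 20)
x = torch.randn(3, 10)  # (batch, input_size)
zeros = torch.zeros(3, 20)
h0, c0 = cell0(x, (zeros, zeros))
h1, c1 = cell1(h0, (zeros, zeros))
print(h1.shape)  # torch.Size([3, 20])
```

Normalizing the two projections separately (rather than their sum) follows the paper's formulation; whether to also keep LayerNorm between layers, as in the question, is a separate design choice.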