Hello, I am confused about LSTMCell's parameters and variables. The docs say that weight_ih has shape (4*hidden_size, input_size) and input has shape (batch, input_size). Obviously, these two cannot be multiplied directly as written in the equation $i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi})$, so I guess some dimension transformation is done in the C++ code.
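In the docs' equations, x and h are column vectors, so $W_{ii} x$ is a (hidden_size × input_size) matrix times an input_size vector. When a batch is stored as rows, the same products become $X W^{\top}$. The worked form below assumes that weight_ih stacks the gate blocks as $(W_{ii} \mid W_{if} \mid W_{ig} \mid W_{io})$ and weight_hh likewise, which is the order the nn.LSTM docs give for weight_ih_l[k]. It is a sketch of the math, not a statement about the C++ internals. Here $X \in \mathbb{R}^{B \times d_{in}}$, $H, C \in \mathbb{R}^{B \times d_h}$, and $B$ is the batch size:

$$
\begin{aligned}
X W_{ih}^{\top} &= \big[\, X W_{ii}^{\top} \mid X W_{if}^{\top} \mid X W_{ig}^{\top} \mid X W_{io}^{\top} \,\big] \in \mathbb{R}^{B \times 4 d_h},\\
G &= X W_{ih}^{\top} + b_{ih} + H W_{hh}^{\top} + b_{hh} = \big[\, G_i \mid G_f \mid G_g \mid G_o \,\big],\\
i &= \sigma(G_i), \quad f = \sigma(G_f), \quad g = \tanh(G_g), \quad o = \sigma(G_o),\\
C' &= f \odot C + i \odot g,\\
H' &= o \odot \tanh(C').
\end{aligned}
$$

Row by row, $G_i = x W_{ii}^{\top} + b_{ii} + h W_{hi}^{\top} + b_{hi}$, which is the documented $W_{ii} x + b_{ii} + W_{hi} h + b_{hi}$ written for row vectors.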
Besides this dimension transformation, are there any other modifications to the parameters or variables? I ask because I implemented LSTMCell in JAX, used the same inputs, weights, and biases as in PyTorch, and still got outputs that differ from PyTorch's LSTMCell.
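For a like-for-like comparison, PyTorch's parameters can be moved into a row-vector layout, (input_size, 4*hidden_size), by transposing them. The sketch below assumes the four tensors were exported as arrays, for example with `cell.weight_ih.detach().numpy()`. The helper name `torch_lstmcell_to_jax` is hypothetical. A transpose does not reorder anything, so the column blocks stay in PyTorch's gate order (i, f, g, o), and any split of the result must use that same order.

```python
import jax.numpy as jnp


def torch_lstmcell_to_jax(weight_ih, weight_hh, bias_ih, bias_hh):
    """Hypothetical helper: PyTorch LSTMCell layout -> row-vector (x @ W) layout.

    weight_ih: (4*hidden_size, input_size)  -> Wi: (input_size, 4*hidden_size)
    weight_hh: (4*hidden_size, hidden_size) -> Wh: (hidden_size, 4*hidden_size)
    Biases are 1-D with shape (4*hidden_size,) and need no reshaping.
    Column blocks of Wi/Wh keep PyTorch's gate order: i, f, g, o.
    """
    Wi = jnp.asarray(weight_ih).T
    Wh = jnp.asarray(weight_hh).T
    bi = jnp.asarray(bias_ih)
    bh = jnp.asarray(bias_hh)
    return Wi, Wh, bi, bh
```

With these arrays, `x @ Wi + bi` computes the same thing as PyTorch's `x @ weight_ih.T + bias_ih`.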
Here is my LSTMCell implementation. Note: self.Wi has shape (input_size, hidden_size*4), self.Wh has shape (hidden_size, hidden_size*4), and the biases self.bi and self.bh have shape (hidden_size*4,).
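```python
from typing import Tuple
import jax
import jax.numpy as jnp

class LSTMCell:
    def __call__(self, carry: Tuple, x: jnp.ndarray):
        h, c = carry
        # Input and hidden contributions to all four gates: (batch, hidden_size*4)
        gated = x @ self.Wi + self.bi
        gated += h @ self.Wh + self.bh
        # Split into four (batch, hidden_size) blocks, read here in the order i, g, f, o
        i, g, f, o = jnp.split(gated, indices_or_sections=4, axis=-1)
        c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
        h = jax.nn.sigmoid(o) * jnp.tanh(c)
        return (h, c), h
```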
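For comparison, here is a minimal JAX sketch of one LSTMCell step that takes PyTorch's parameter layout directly. It assumes a dict whose keys mirror PyTorch's attribute names (`weight_ih`, `weight_hh`, `bias_ih`, `bias_hh`). The function name `lstm_cell_torch_layout` is hypothetical. It computes `x @ W.T + b` for both paths, which is what `torch.nn.functional.linear` does, and splits the gates in PyTorch's order: input, forget, cell (g), output.

```python
import jax
import jax.numpy as jnp


def lstm_cell_torch_layout(params, carry, x):
    """Hypothetical sketch of one LSTMCell step using PyTorch's parameter layout.

    params: dict with keys "weight_ih" (4H, I), "weight_hh" (4H, H),
            "bias_ih" (4H,), "bias_hh" (4H,), named after PyTorch's attributes
    carry:  (h, c), each of shape (batch, H)
    x:      (batch, I)
    """
    h, c = carry
    # Same as F.linear on each path: x @ W.T + b, giving (batch, 4H)
    gates = (x @ params["weight_ih"].T + params["bias_ih"]
             + h @ params["weight_hh"].T + params["bias_hh"])
    # Gate blocks are stacked as input, forget, cell (g), output
    i, f, g, o = jnp.split(gates, 4, axis=-1)
    i, f, o = jax.nn.sigmoid(i), jax.nn.sigmoid(f), jax.nn.sigmoid(o)
    g = jnp.tanh(g)
    c_new = f * c + i * g
    h_new = o * jnp.tanh(c_new)
    return (h_new, c_new), h_new


# Usage: hidden_size = input_size = 1, zero weights, biases chosen per gate
params = {
    "weight_ih": jnp.zeros((4, 1)),
    "weight_hh": jnp.zeros((4, 1)),
    "bias_ih": jnp.array([100.0, -100.0, 100.0, 100.0]),  # i, f, g, o
    "bias_hh": jnp.zeros(4),
}
(h1, c1), out = lstm_cell_torch_layout(
    params, (jnp.zeros((1, 1)), jnp.full((1, 1), 5.0)), jnp.zeros((1, 1)))
```

In the usage example the forget gate is driven to about 0 and the input, cell, and output gates to about 1, so the old cell state 5.0 is dropped and `c1` comes out near 1.0. If the same biases were read in a different gate order, the four activations would land on different gates and `c1` would change. This is one way to check which order a set of weights assumes.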
Thanks!