How is LSTMCell implemented in C++?

Hello, I am confused about LSTMCell's parameters and variables. The docs say weight_ih has shape (4*hidden_size, input_size) and input has shape (batch, input_size). These obviously cannot be multiplied directly as written in the equation i = σ(W_ii x + b_ii + W_hi h + b_hi), so I guess some dimension transformation is done in C++.
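For reference, the docs write the gate equations for a single sample x treated as a column vector. With a batch stored row-wise, the same product is written with the weight transposed. Below is a minimal JAX sketch of how a (4*hidden_size, input_size) weight lines up with a (batch, input_size) input. The sizes and random values are arbitrary, and it assumes PyTorch's stacking of the four gate matrices as (W_ii|W_if|W_ig|W_io), as listed in the nn.LSTM documentation, so the first hidden_size rows play the role of W_ii.

```python
import jax
import jax.numpy as jnp

batch, input_size, hidden_size = 3, 5, 4  # arbitrary example sizes

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
# Same shape as PyTorch's LSTMCell.weight_ih: (4*hidden_size, input_size)
weight_ih = jax.random.normal(k1, (4 * hidden_size, input_size))
x = jax.random.normal(k2, (batch, input_size))  # one sample per row

# Batched form: rows of x are samples, so multiply by the transpose.
gates_x = x @ weight_ih.T  # (batch, 4*hidden_size)

# Assumed stacking: rows [0:H] = W_ii, [H:2H] = W_if, [2H:3H] = W_ig, [3H:4H] = W_io
W_ii = weight_ih[:hidden_size]  # (hidden_size, input_size)

# Per-sample form from the docs: W_ii @ x_b for one column vector x_b.
i_pre_sample0 = W_ii @ x[0]  # (hidden_size,)

print(gates_x.shape)  # (3, 16)
print(jnp.allclose(gates_x[0, :hidden_size], i_pre_sample0, atol=1e-5))
```

So no special reshaping is needed beyond the transpose: x @ weight_ih.T computes W_ii x, W_if x, W_ig x and W_io x for every sample at once, side by side in the last dimension.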
Besides the dimension transformation, are there any other modifications to the parameters and variables? I ask because I implemented LSTMCell in JAX, kept the same inputs, weights, and biases as PyTorch's, and still got outputs that differ from those of PyTorch's LSTMCell.
Here is my LSTMCell implementation. Note: self.Wi has shape (input_size, hidden_size*4), self.Wh has shape (hidden_size, hidden_size*4), and the biases self.bi and self.bh have shape (hidden_size*4,).

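```python
from typing import Tuple

import jax
import jax.numpy as jnp

# Method of my JAX LSTMCell class (the rest of the class is omitted).
def __call__(self, carry: Tuple, x: jnp.ndarray):
    h, c = carry
    # Pre-activations for all four gates: (batch, hidden_size*4)
    gated = x @ self.Wi + self.bi
    gated += h @ self.Wh + self.bh

    # Split into four gate blocks, taken here in the order i, g, f, o
    i, g, f, o = jnp.split(gated, indices_or_sections=4, axis=-1)
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return (h, c), h
```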
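For comparison, the nn.LSTM documentation lists the blocks of weight_ih as (W_ii|W_if|W_ig|W_io), i.e. input, forget, cell (g), output, and PyTorch computes x @ weight_ih.T + bias_ih (plus the hidden-state term). Below is a minimal standalone sketch of one step that keeps the (input_size, hidden_size*4) layout described above but splits the gates in that i, f, g, o order. It also includes a helper that transposes PyTorch-shaped parameters into this layout. Both function names (from_torch_layout, lstm_cell_torch_order) are hypothetical, and the random arrays stand in for a real cell's weight_ih, weight_hh, bias_ih and bias_hh.

```python
from typing import Tuple

import jax
import jax.numpy as jnp

def from_torch_layout(weight_ih, weight_hh, bias_ih, bias_hh):
    """Hypothetical helper: transpose PyTorch-shaped LSTMCell parameters
    ((4*hidden, input), (4*hidden, hidden)) into the (input, 4*hidden) /
    (hidden, 4*hidden) layout used above. Gate block order is unchanged."""
    return (jnp.asarray(weight_ih).T, jnp.asarray(weight_hh).T,
            jnp.asarray(bias_ih), jnp.asarray(bias_hh))

def lstm_cell_torch_order(params, carry: Tuple[jnp.ndarray, jnp.ndarray], x: jnp.ndarray):
    """One LSTM step with gates split in PyTorch's documented order i, f, g, o."""
    Wi, Wh, bi, bh = params
    h, c = carry
    gates = x @ Wi + bi + h @ Wh + bh          # (batch, 4*hidden_size)
    i, f, g, o = jnp.split(gates, 4, axis=-1)  # input, forget, cell, output
    c_new = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h_new = jax.nn.sigmoid(o) * jnp.tanh(c_new)
    return (h_new, c_new), h_new

# Usage: random stand-ins for weight_ih, weight_hh, bias_ih, bias_hh.
batch, input_size, hidden_size = 2, 3, 4
keys = jax.random.split(jax.random.PRNGKey(0), 5)
params = from_torch_layout(
    jax.random.normal(keys[0], (4 * hidden_size, input_size)),
    jax.random.normal(keys[1], (4 * hidden_size, hidden_size)),
    jax.random.normal(keys[2], (4 * hidden_size,)),
    jax.random.normal(keys[3], (4 * hidden_size,)),
)
x = jax.random.normal(keys[4], (batch, input_size))
h0 = jnp.zeros((batch, hidden_size))
c0 = jnp.zeros((batch, hidden_size))
(h1, c1), out = lstm_cell_torch_order(params, (h0, c0), x)
print(h1.shape)  # (2, 4)
```

The only difference from the implementation above is the unpacking order of jnp.split: with the same weights, taking the second block as g instead of f changes every output.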

Thanks!