Implementing an LSTM from scratch leads to different results

Hello,
I am implementing an LSTM from scratch and comparing it with the PyTorch LSTM. However, the results I get with the PyTorch LSTM are better than those from my own implementation. Could you tell me what is wrong with the code below, or what I can adjust to get the correct result?

import math

import torch
import torch.nn as nn


class LSTMCell(nn.Module):

    def __init__(self, input_size, hidden_size):
        super(LSTMCell, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.x2h = nn.Linear(input_size, 4 * hidden_size)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size)
        self.tanh = nn.Tanh()
        self.init_parameters()
    
    def init_parameters(self):
        std = 1.0 / math.sqrt(self.hidden_size)
        for p in self.parameters():
            p.data.uniform_(-std, std)

    def forward(self, inp, states):
        """
        inp shape: (batch_size, input_size)
        each of states shape: (batch_size, hidden_size)
        """
        
        ht, ct = states
        gates = self.x2h(inp) + self.h2h(ht)    # (batch_size, 4 * hidden_size)

        in_gate, forget_gate, new_memory, out_gate = gates.chunk(4, 1)
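        # gate order here: input gate (i), forget gate (f), candidate cell state (g), output gate (o)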

        in_gate = torch.sigmoid(in_gate)
        forget_gate = torch.sigmoid(forget_gate)
        out_gate = torch.sigmoid(out_gate)
        new_memory = self.tanh(new_memory)
        c_new = (forget_gate * ct) + (in_gate * new_memory)
        h_new = out_gate * self.tanh(c_new)

        return (h_new, c_new)
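
For context, this is roughly how I unroll the cell over a sequence (a simplified sketch for this post, not my exact training loop; the helper name run_lstm is just for illustration):

def run_lstm(cell, inputs):
    """
    inputs shape: (batch_size, seq_len, input_size)
    returns: outputs of shape (batch_size, seq_len, hidden_size) and the final (h, c)
    """
    batch_size, seq_len, _ = inputs.shape
    # start from zero hidden and cell states
    h = inputs.new_zeros(batch_size, cell.hidden_size)
    c = inputs.new_zeros(batch_size, cell.hidden_size)
    outputs = []
    for t in range(seq_len):
        h, c = cell(inputs[:, t, :], (h, c))
        outputs.append(h)
    return torch.stack(outputs, dim=1), (h, c)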

Thanks!

@Fawaz_Sammani can you help please? I believe you understand this topic, since your answer to my last post was closely related. Hope to hear from you :slight_smile:

It seems correct to me. I've used the same code in one of my projects before (just with different variable names) and got results very close to the PyTorch LSTM. Could you please post your results and the code you use with the PyTorch LSTM?
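
In the meantime, one quick sanity check you could try (a rough sketch, assuming the same setup as your post; the names ref and mine are just for illustration) is to copy the weights of an nn.LSTMCell into your cell and confirm the outputs match:

torch.manual_seed(0)
input_size, hidden_size, batch_size = 8, 16, 4

ref = nn.LSTMCell(input_size, hidden_size)
mine = LSTMCell(input_size, hidden_size)

# PyTorch stores the gates in (i, f, g, o) order, the same order your chunk() assumes,
# so the weight matrices can be copied over directly
mine.x2h.weight.data.copy_(ref.weight_ih.data)
mine.x2h.bias.data.copy_(ref.bias_ih.data)
mine.h2h.weight.data.copy_(ref.weight_hh.data)
mine.h2h.bias.data.copy_(ref.bias_hh.data)

x = torch.randn(batch_size, input_size)
h0 = torch.zeros(batch_size, hidden_size)
c0 = torch.zeros(batch_size, hidden_size)

h_ref, c_ref = ref(x, (h0, c0))
h_new, c_new = mine(x, (h0, c0))
print(torch.allclose(h_ref, h_new, atol=1e-6), torch.allclose(c_ref, c_new, atol=1e-6))

If both print True, the cell itself is fine and the difference is more likely in the training setup (initialization, sequence handling, or hyperparameters) than in the LSTM math.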