Hello,
I am implementing an LSTM from scratch and comparing it with the PyTorch LSTM. However, the results I get with the PyTorch LSTM are better than those from my own implementation. What is wrong in the code below, or what should be adjusted to get the correct result?
class LSTMCell(nn.Module):
    """Single-step LSTM cell built from two linear projections.

    The input and the previous hidden state are each projected to
    ``4 * hidden_size`` features, summed, and split into the input gate,
    forget gate, candidate memory and output gate, in that order. This is
    the same gate layout that ``torch.nn.LSTMCell`` uses for its stacked
    weights, so parameters can be copied between the two for comparison.
    """

    def __init__(self, input_size, hidden_size):
        """Create the projections and initialise their parameters.

        Args:
            input_size: Number of features in each input vector.
            hidden_size: Number of features in the hidden and cell states.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        # Each projection yields all four gate pre-activations at once.
        self.x2h = nn.Linear(input_size, 4 * hidden_size)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size)
        self.tanh = nn.Tanh()
        self.init_parameters()

    def init_parameters(self):
        """Draw every weight and bias from U(-1/sqrt(H), 1/sqrt(H)).

        ``H`` is ``hidden_size``. ``nn.init.uniform_`` fills the tensors in
        place without recording the operation for autograd, which avoids
        touching ``.data`` directly.
        """
        bound = 1.0 / math.sqrt(self.hidden_size)
        for param in self.parameters():
            nn.init.uniform_(param, -bound, bound)

    def forward(self, inp, states=None):
        """Advance the cell by one time step.

        Args:
            inp: Input of shape ``(batch_size, input_size)``.
            states: Tuple ``(h, c)`` with each tensor of shape
                ``(batch_size, hidden_size)``. When omitted, both states
                start as zeros with the dtype and device of ``inp``.

        Returns:
            Tuple ``(h_new, c_new)``, each of shape
            ``(batch_size, hidden_size)``.
        """
        if states is None:
            zeros = inp.new_zeros(inp.size(0), self.hidden_size)
            states = (zeros, zeros)
        ht, ct = states

        # (batch_size, 4 * hidden_size): all gate pre-activations together.
        gates = self.x2h(inp) + self.h2h(ht)
        in_gate, forget_gate, new_memory, out_gate = gates.chunk(4, 1)

        in_gate = torch.sigmoid(in_gate)
        forget_gate = torch.sigmoid(forget_gate)
        out_gate = torch.sigmoid(out_gate)
        new_memory = self.tanh(new_memory)

        # Keep part of the old memory, then add the gated candidate.
        c_new = (forget_gate * ct) + (in_gate * new_memory)
        h_new = out_gate * self.tanh(c_new)
        return (h_new, c_new)
Thanks!