Is my LSTM wrong?

I implemented an LSTM step by step, but its output does not match the official `nn.LSTM`.
I get different results even though I feed in only a single row of data and initialize both with the same parameters. A minimal example:

# official
# Reproduction: compare torch.nn.LSTM with a hand-written single LSTM step on
# one input row. Both sides use the same random input-to-hidden weights
# (share_weight), all-ones hidden-to-hidden weights and zero biases.
import torch
import torch.nn as nn
import torch.nn.functional as F



torch.manual_seed(20)
# Single-layer, unidirectional LSTM with input_size=6, hidden_size=2.
lstm_official = torch.nn.LSTM(6, 2, bidirectional=False, num_layers=1, batch_first=False)
# Random input-to-hidden weights, reused by the manual version below.
share_weight = torch.randn(lstm_official .weight_ih_l0.shape,dtype = torch.float)
lstm_official .weight_ih_l0 = torch.nn.Parameter(share_weight)
# bias set to zeros
lstm_official .bias_ih_l0 = torch.nn.Parameter(torch.zeros(lstm_official .bias_ih_l0.shape))
lstm_official .weight_hh_l0 = torch.nn.Parameter(torch.ones(lstm_official .weight_hh_l0.shape))
# bias set to zeros (reuses bias_ih_l0's shape; per the nn.LSTM docs both
# biases are 4*hidden_size long)
lstm_official .bias_hh_l0 = torch.nn.Parameter(torch.zeros(lstm_official .bias_ih_l0.shape))
x = torch.tensor([[1,2,3,4,5,6],[1,2,3,4,5,6]],dtype=torch.float)
# Only the first row is fed in, reshaped to (seq_len=1, batch=1, 6) because
# batch_first=False.
lstm_official_out = lstm_official (x[0].unsqueeze(dim=0).unsqueeze(dim=0))

# implementation step by step
out_shape=2  # hidden size
batchsize=1
i2h = nn.Linear(in_features=6, out_features=8)  # 8 = 4 gates * hidden size 2
h2h = nn.Linear(in_features=out_shape, out_features=8)
i2h.weight = torch.nn.Parameter(share_weight)
i2h.bias = torch.nn.Parameter(torch.zeros(i2h.bias.shape))
h2h.weight = torch.nn.Parameter(torch.ones(h2h.weight.shape))
h2h.bias = torch.nn.Parameter(torch.zeros(h2h.bias.shape))
x_i2h = i2h(x[0].unsqueeze(dim=0))
# Initial hidden and cell states are zero.
prev_h = torch.zeros((batchsize,2))
prev_c = torch.zeros((batchsize,2))
x_h2h = h2h(prev_h)
# Stacked pre-activations of all four gates, split into hidden-size chunks.
gates = x_i2h + x_h2h
gates = torch.split(gates,out_shape,-1)
# NOTE(review): nn.LSTM packs its gate weights as (input, forget, cell, output),
# so gates[1] is the forget gate and gates[2] is the cell candidate. The two
# lines below swap them (tanh on gates[1], sigmoid on gates[2]), which is why
# this result differs from nn.LSTM. Correct form:
#   in_transform = torch.tanh(gates[2]); forget_gate = torch.sigmoid(gates[1])
in_gate = torch.sigmoid(gates[0])
in_transform = torch.tanh(gates[1])
forget_gate = torch.sigmoid(gates[2])
out_gate = torch.sigmoid(gates[3])
print(in_gate,in_transform,forget_gate,out_gate)
# c_t = f * c_{t-1} + i * g
s0 = forget_gate * prev_c
s1 = in_gate * in_transform
next_c = s0 + s1
# h_t = o * tanh(c_t)
next_h = out_gate * F.tanh(next_c)


# The official output is (seq_len, batch, hidden) = (1, 1, 2), while next_h
# is (batch, hidden) = (1, 2), so the printed tensors differ in nesting too.
print(f'official:{lstm_official_out[0]}')
print(f'step:{next_h}')


I got:

official:tensor([[[0.0865, 0.0915]]], grad_fn=)

step:tensor([[-0.0745, -0.0915]], grad_fn=)

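I think you are mixing up `in_transform` and `forget_gate`. `nn.LSTM` packs its gate weights in the order input, forget, cell, output (i, f, g, o), so `gates[1]` is the forget gate and `gates[2]` is the cell candidate. The code below should work. I've also added my own manual implementation, which simply reuses the formulas from the docs: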

# nn.LSTM
# Standalone comparison of torch.nn.LSTM against a manual single-step LSTM
# written directly from the equations in the nn.LSTM documentation.
import torch
import torch.nn as nn  # used by the step-by-step implementation that follows
import torch.nn.functional as F

torch.manual_seed(20)
# Single-layer, unidirectional LSTM with input_size=6, hidden_size=2.
lstm_official = torch.nn.LSTM(6, 2, bidirectional=False, num_layers=1, batch_first=False)
# Random input-to-hidden weights, also shared with the step-by-step version.
share_weight = torch.randn(lstm_official.weight_ih_l0.shape, dtype=torch.float)
lstm_official.weight_ih_l0 = torch.nn.Parameter(share_weight)
# Zero both biases and use all-ones hidden-to-hidden weights so the
# implementations are easy to compare.
lstm_official.bias_ih_l0 = torch.nn.Parameter(torch.zeros(lstm_official.bias_ih_l0.shape))
lstm_official.weight_hh_l0 = torch.nn.Parameter(torch.ones(lstm_official.weight_hh_l0.shape))
lstm_official.bias_hh_l0 = torch.nn.Parameter(torch.zeros(lstm_official.bias_hh_l0.shape))
x = torch.tensor([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]], dtype=torch.float)
# Feed only the first row as a (seq_len=1, batch=1, input_size=6) sequence.
lstm_official_out = lstm_official(x[0].unsqueeze(dim=0).unsqueeze(dim=0))

# manual implementation
hidden_size = lstm_official.hidden_size
batchsize = 1  # defined before use so this snippet runs on its own

# The per-gate parameters are stacked along dim 0 in the order
# (input, forget, cell, output), each chunk hidden_size rows long.
W_ii, W_if, W_ig, W_io = lstm_official.weight_ih_l0.split(hidden_size, dim=0)
b_ii, b_if, b_ig, b_io = lstm_official.bias_ih_l0.split(hidden_size, dim=0)

W_hi, W_hf, W_hg, W_ho = lstm_official.weight_hh_l0.split(hidden_size, dim=0)
b_hi, b_hf, b_hg, b_ho = lstm_official.bias_hh_l0.split(hidden_size, dim=0)

# (batch=1, input_size) row; named x_t to avoid shadowing the builtin input().
x_t = x[0].unsqueeze(0)
# Initial hidden and cell states are zero.
prev_h = torch.zeros((batchsize, hidden_size))
prev_c = torch.zeros((batchsize, hidden_size))

i_t = torch.sigmoid(F.linear(x_t, W_ii, b_ii) + F.linear(prev_h, W_hi, b_hi))
f_t = torch.sigmoid(F.linear(x_t, W_if, b_if) + F.linear(prev_h, W_hf, b_hf))
g_t = torch.tanh(F.linear(x_t, W_ig, b_ig) + F.linear(prev_h, W_hg, b_hg))
o_t = torch.sigmoid(F.linear(x_t, W_io, b_io) + F.linear(prev_h, W_ho, b_ho))
c_t = f_t * prev_c + i_t * g_t
h_t = o_t * torch.tanh(c_t)

print('nn.LSTM output {}, manual output {}'.format(lstm_official_out[0], h_t))
print('nn.LSTM hidden {}, manual hidden {}'.format(lstm_official_out[1][0], h_t))
print('nn.LSTM state {}, manual state {}'.format(lstm_official_out[1][1], c_t))

# implementation step by step
# Fused-gate LSTM step: one Linear produces all four gate pre-activations
# from the input and another from the previous hidden state.
out_shape = 2  # hidden size
batchsize = 1
n_gates = 4  # input, forget, cell candidate, output
i2h = nn.Linear(in_features=6, out_features=n_gates * out_shape)
h2h = nn.Linear(in_features=out_shape, out_features=n_gates * out_shape)
i2h.weight = torch.nn.Parameter(share_weight)
i2h.bias = torch.nn.Parameter(torch.zeros(i2h.bias.shape))
h2h.weight = torch.nn.Parameter(torch.ones(h2h.weight.shape))
h2h.bias = torch.nn.Parameter(torch.zeros(h2h.bias.shape))

x_i2h = i2h(x[0].unsqueeze(dim=0))
# Initial hidden and cell states are zero, sized from out_shape.
prev_h = torch.zeros((batchsize, out_shape))
prev_c = torch.zeros((batchsize, out_shape))
x_h2h = h2h(prev_h)
gates = x_i2h + x_h2h
# Same packing order as nn.LSTM: (input, forget, cell candidate, output).
gates = torch.split(gates, out_shape, -1)
in_gate = torch.sigmoid(gates[0])
forget_gate = torch.sigmoid(gates[1])
in_transform = torch.tanh(gates[2])
out_gate = torch.sigmoid(gates[3])
print(in_gate, in_transform, forget_gate, out_gate)
# c_t = f * c_{t-1} + i * g ;  h_t = o * tanh(c_t)
s0 = forget_gate * prev_c
s1 = in_gate * in_transform
next_c = s0 + s1
next_h = out_gate * torch.tanh(next_c)

print(f'official:{lstm_official_out[0]}')
print(f'step:{next_h}')
1 Like

It works well now, thank you! The mistake was mine.