Stability issue when replicating the behavior of PyTorch's nn.LSTM

Hello guys,
I have implemented a custom LSTM in PyTorch to replicate the behavior of PyTorch’s built-in nn.LSTM module. When I initialize the weights from a normal distribution with mean -0.0346 and standard deviation 16.017, the outputs of my custom LSTM and PyTorch’s LSTM start to diverge significantly after a couple of timesteps. I have confirmed that both LSTMs produce nearly identical outputs when the weights are initialized with a much smaller standard deviation (e.g., weights scaled by 0.01).

Of course, this issue might be caused by the poor initialization values, which make the recurrence unstable. However, my main question is why the outputs differ at all, given that the custom LSTM is implemented to closely match the operations of PyTorch’s LSTM. I need to determine whether the divergence comes from a difference between the implementations or from inherent numerical instability.

Here is the code for both LSTM implementations and the comparison:

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(30)


use_GPU = True
# Default to the CPU; upgrade to CUDA only when GPU use is requested and a
# CUDA device is actually present.
device = torch.device("cpu")
if use_GPU:
    if torch.cuda.is_available():
        device = torch.device("cuda")
    print("Using device:", device)
  
  
  
# Activation functions
def activation(x):
    """Return the element-wise hyperbolic tangent of tensor *x*."""
    return x.tanh()

def recurrent_activation(x):
    """Return the element-wise logistic sigmoid of tensor *x* (gate nonlinearity)."""
    return x.sigmoid()

# custom LSTM implementation
class CustomLSTM(nn.Module):
    """Single-layer LSTM computed step by step to mirror ``nn.LSTM``.

    Args:
        weights: sequence ``(weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0)``
            laid out like the corresponding ``nn.LSTM`` parameters: the four
            gate blocks are stacked along dim 0 in the order input, forget,
            cell (candidate), output. Each entry may be a NumPy array or a
            tensor; it is copied into a new float32 ``nn.Parameter``.
    """

    def __init__(self, weights):
        super().__init__()
        weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0 = weights
        self.w = self._to_param(weight_ih_l0)
        self.w_rec = self._to_param(weight_hh_l0)
        self.bias = self._to_param(bias_ih_l0)
        self.bias_rec = self._to_param(bias_hh_l0)
        # The four gates are stacked along weight_ih's first dimension.
        self.units = weight_ih_l0.shape[0] // 4

    @staticmethod
    def _to_param(value):
        """Copy *value* (array or tensor) into an independent float32 Parameter."""
        # as_tensor + clone avoids torch.tensor's copy-construct warning for
        # tensor inputs while never aliasing the caller's storage.
        return nn.Parameter(torch.as_tensor(value, dtype=torch.float32).detach().clone())

    def forward(self, inputs):
        """Run the LSTM over ``inputs``, starting from zero hidden and cell states.

        Args:
            inputs: rank-3 tensor indexed as ``inputs[:, t, :]`` per timestep,
                i.e. ``(batch, seq_len, features)``.

        Returns:
            ``(outputs, (h_n, c_n), h_states, c_states)``. ``outputs`` and
            ``h_states`` both stack the hidden state of every step along dim 1.
            ``c_states`` does the same for the cell state. ``(h_n, c_n)`` are
            the states after the last step.
        """
        batch_size, seq_len, _ = inputs.shape
        # Allocate the initial states directly on the input's device with the
        # parameters' dtype, instead of building on the CPU and copying over.
        h_tm1 = torch.zeros(batch_size, self.units, device=inputs.device, dtype=self.w_rec.dtype)
        c_tm1 = torch.zeros_like(h_tm1)
        h_list = []
        c_list = []
        # `t` is the time index. The input gate below is named `in_gate`, so
        # the loop variable is no longer shadowed inside the body.
        for t in range(seq_len):
            x_t = inputs[:, t, :]
            # Same grouping as nn.LSTM's formula: (W_ih x + b_ih) + (W_hh h + b_hh).
            # A fused backend (e.g. cuDNN) may accumulate these terms in a
            # different order, so agreement is only up to floating-point
            # rounding, and a chaotic recurrence can amplify that gap.
            z = (torch.matmul(x_t, self.w.t()) + self.bias) + (
                torch.matmul(h_tm1, self.w_rec.t()) + self.bias_rec
            )
            z_i, z_f, z_g, z_o = torch.split(z, self.units, dim=1)
            in_gate = recurrent_activation(z_i)
            forget_gate = recurrent_activation(z_f)
            cell_candidate = activation(z_g)
            out_gate = recurrent_activation(z_o)
            c_t = forget_gate * c_tm1 + in_gate * cell_candidate
            h_t = out_gate * activation(c_t)
            h_list.append(h_t)
            c_list.append(c_t)
            h_tm1, c_tm1 = h_t, c_t

        outputs = torch.stack(h_list, dim=1)
        h_states = torch.stack(h_list, dim=1)
        c_states = torch.stack(c_list, dim=1)
        return outputs, (h_tm1, c_tm1), h_states, c_states


# Define hyperparameters
batch_size = 1
seq_len = 100
input_size = 32
output_size = 64

# Make the comparison apples-to-apples on TF32-capable (Ampere or newer) GPUs.
# cuDNN, which nn.LSTM uses on CUDA, is allowed to use TF32 by default. The
# custom LSTM's matmuls go through cuBLAS, where TF32 is off by default in
# recent PyTorch releases. Force both paths to full FP32. These flags only
# matter on CUDA devices that support TF32.
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False

# Reasonable weight initialization (swap in to see the two LSTMs agree)
# w_rand = torch.randn(input_size, 4 * output_size) * 0.01
# w_rand_rec = torch.randn(output_size, 4 * output_size) * 0.01

# Very large weights (std ~16) to reproduce the divergence
w_rand = torch.normal(-0.0346, 16.017, size=(input_size, 4 * output_size))
w_rand_rec = torch.normal(-0.0346, 16.017, size=(output_size, 4 * output_size))


# Construct an LSTM layer. Keep it between the torch.normal draws and the
# input draw below: its default initialization consumes the seeded RNG, so
# moving it would change input_tensor.
lstm_layer = torch.nn.LSTM(input_size, output_size, num_layers=1, bias=True, batch_first=True).to(device)


# nn.LSTM stores weight_ih_l0 as (4*hidden, input) and weight_hh_l0 as
# (4*hidden, hidden), hence the transposes.
w_rand = w_rand.detach().cpu().numpy().T
w_rand_rec = w_rand_rec.detach().cpu().numpy().T
with torch.no_grad():
    lstm_layer.weight_ih_l0.copy_(torch.tensor(w_rand, device=device))
    lstm_layer.weight_hh_l0.copy_(torch.tensor(w_rand_rec, device=device))
    # Put away biases for now: zero them in place (no temporary tensors).
    lstm_layer.bias_ih_l0.zero_()
    lstm_layer.bias_hh_l0.zero_()


# Extract weights and biases so CustomLSTM receives exactly the same values
weight_ih_l0 = lstm_layer.weight_ih_l0.detach().cpu().numpy()
weight_hh_l0 = lstm_layer.weight_hh_l0.detach().cpu().numpy()
bias_ih_l0 = lstm_layer.bias_ih_l0.detach().cpu().numpy()
bias_hh_l0 = lstm_layer.bias_hh_l0.detach().cpu().numpy()
lstm_v2_weights = [weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0]

# Initialize CustomLSTM
CustomLSTM_layer = CustomLSTM(lstm_v2_weights).to(device)

# Create a random input tensor
input_tensor = torch.randn(batch_size, seq_len, input_size).to(device)

# Get outputs from both LSTM layers
with torch.no_grad():
    output_ref, (ref_h_t, ref_c_t) = lstm_layer(input_tensor)
    output_custom, (h_t, c_t), hidden_states, cell_states = CustomLSTM_layer(input_tensor)

print("*"*45)
print(f"Ref stats (mean, std): ({torch.mean(output_ref).item():.6f}, {torch.std(output_ref).item():.6f})")
print(f"Cus stats (mean, std): ({torch.mean(output_custom).item():.6f}, {torch.std(output_custom).item():.6f})")
# Mean/std alone can hide or exaggerate divergence, so also compare the two
# outputs element-wise. With batch_first outputs of shape (batch, seq, hidden),
# reducing over dims 0 and 2 gives the worst-case |difference| per timestep.
divergence_tol = 1e-4
per_step_diff = (output_ref - output_custom).abs().amax(dim=(0, 2))
print(f"Max |ref - cus| over all timesteps: {per_step_diff.max().item():.3e}")
diverged_steps = torch.nonzero(per_step_diff > divergence_tol).flatten()
if diverged_steps.numel() > 0:
    print(f"First timestep with max |ref - cus| > {divergence_tol:g}: {diverged_steps[0].item()}")
else:
    print(f"No timestep has max |ref - cus| > {divergence_tol:g}")
print("*"*45)

Thanks in advance for your help,