I tried to build my own ReLU LSTM here. Could you have a look at my solution? Is this implementation correct? Thank you.
import torch
import torch.nn as nn
class CustomLSTMCell(nn.Module):
    """Single-step LSTM cell that uses ReLU where a standard LSTM uses tanh.

    The input, forget and output gates keep their sigmoid activation; only
    the candidate cell value and the hidden-state read-out use ReLU.

    Note on the ReLU choice: the candidate is non-negative and the gates lie
    in [0, 1], so a non-negative ``c_prev`` always produces a non-negative
    ``c_t`` (the read-out ReLU is then an identity). The candidate is also
    unbounded above, so the cell state can grow without limit over many
    steps; if training diverges, gradient clipping or normalisation is worth
    considering.
    """

    def __init__(self, input_size, hidden_size):
        """Create the four gate projections.

        Args:
            input_size: Number of features in each input vector.
            hidden_size: Number of features in the hidden and cell states.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Every gate sees the input concatenated with the previous hidden
        # state. The layers are created in the order i, f, c, o so parameter
        # names and seeded initialisation stay the same as before.
        gate_in_features = input_size + hidden_size
        self.w_i = nn.Linear(gate_in_features, hidden_size)
        self.w_f = nn.Linear(gate_in_features, hidden_size)
        self.w_c = nn.Linear(gate_in_features, hidden_size)
        self.w_o = nn.Linear(gate_in_features, hidden_size)

        # Activation functions
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, h_prev=None, c_prev=None):
        """Advance the cell by one time step.

        Args:
            x: Input of shape ``(batch, input_size)``.
            h_prev: Previous hidden state of shape ``(batch, hidden_size)``.
                Defaults to zeros when omitted.
            c_prev: Previous cell state of shape ``(batch, hidden_size)``.
                Defaults to zeros when omitted.

        Returns:
            Tuple ``(h_t, c_t)`` with the new hidden and cell states, each of
            shape ``(batch, hidden_size)``.
        """
        # new_zeros keeps the default state on the same device/dtype as x.
        if h_prev is None:
            h_prev = x.new_zeros(x.size(0), self.hidden_size)
        if c_prev is None:
            c_prev = x.new_zeros(x.size(0), self.hidden_size)

        combined = torch.cat((x, h_prev), dim=1)

        i_t = self.sigmoid(self.w_i(combined))   # input gate
        f_t = self.sigmoid(self.w_f(combined))   # forget gate
        o_t = self.sigmoid(self.w_o(combined))   # output gate
        c_tilde = self.relu(self.w_c(combined))  # candidate (ReLU, not tanh)

        c_t = f_t * c_prev + i_t * c_tilde       # cell state update
        h_t = o_t * self.relu(c_t)               # hidden state (ReLU, not tanh)
        return h_t, c_t
class CustomLSTM(nn.Module):
    """Stack of ``CustomLSTMCell`` layers advanced by one time step per call.

    Each call to ``forward`` consumes a single input of shape
    ``(batch, input_size)``; there is no loop over time inside the module.
    To process a sequence, call the module once per step and feed the
    returned states back in as ``h_prev`` / ``c_prev``.
    """

    def __init__(self, input_size, hidden_size, num_layers):
        """Build ``num_layers`` stacked cells.

        Args:
            input_size: Number of features in the input to the first layer.
            hidden_size: Number of features in every layer's states.
            num_layers: Number of stacked cells; must be at least 1.

        Raises:
            ValueError: If ``num_layers`` is smaller than 1 (``forward``
                would otherwise fail when stacking an empty list of states).
        """
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Only the first layer sees the raw input; deeper layers consume the
        # hidden state produced by the layer below.
        self.lstm_cells = nn.ModuleList([
            CustomLSTMCell(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])

    def forward(self, x, h_prev=None, c_prev=None):
        """Run one time step through every layer.

        Args:
            x: Input of shape ``(batch, input_size)``.
            h_prev: Per-layer previous hidden states, indexable by layer
                (e.g. a tensor of shape ``(num_layers, batch, hidden_size)``).
                Defaults to zeros when omitted.
            c_prev: Per-layer previous cell states, same layout as
                ``h_prev``. Defaults to zeros when omitted.

        Returns:
            Tuple ``(h_next, c_next)``, each of shape
            ``(num_layers, batch, hidden_size)``; ``h_next[-1]`` is the
            output of the top layer.
        """
        # new_zeros keeps the default states on the same device/dtype as x.
        if h_prev is None:
            h_prev = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        if c_prev is None:
            c_prev = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)

        h_states, c_states = [], []
        layer_input = x
        for layer, cell in enumerate(self.lstm_cells):
            h_t, c_t = cell(layer_input, h_prev[layer], c_prev[layer])
            h_states.append(h_t)
            c_states.append(c_t)
            layer_input = h_t  # this layer's output feeds the next layer
        return torch.stack(h_states), torch.stack(c_states)
# Demo: push a single input vector through a three-layer stack for one step.
input_size, hidden_size, num_layers = 10, 20, 3
lstm_model = CustomLSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
)

# Batch of one sample; every layer gets its own randomly drawn (h, c) pair.
x = torch.randn((1, input_size))
_state_shape = (num_layers, 1, hidden_size)
h_prev = torch.randn(_state_shape)
c_prev = torch.randn(_state_shape)

h_next, c_next = lstm_model(x, h_prev=h_prev, c_prev=c_prev)
print("Next Hidden States:", h_next)
print("Next Cell States:", c_next)