Setting ReLU in LSTM?

        self.encoder1 = nn.LSTM(units, units, batch_first=True, num_layers=layers, dropout=0.2)
        self.encoder1.activation = nn.ReLU()

After some searching, I am still confused about whether this actually changes the activation function from tanh to ReLU. Is this correct? I have seen people say that it needs to be implemented manually, but the code above still runs without errors.

Manipulating internal attributes after the module was created might not have any effect, as the backend kernel could already be selected. E.g., in your use case you could profile the code and check which kernels are invoked.
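As a minimal sketch of that check (using torch.profiler; the exact kernel names depend on your PyTorch build and device, so treat the aten::lstm name as an example):

import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

lstm = nn.LSTM(32, 32, batch_first=True, num_layers=2, dropout=0.2)
lstm.activation = nn.ReLU()  # only attaches a new attribute; nn.LSTM never reads it

x = torch.randn(4, 10, 32)
with profile(activities=[ProfilerActivity.CPU]) as prof:
    lstm(x)

# The table should still list the fused, tanh-based LSTM kernels (e.g. aten::lstm),
# which indicates the assignment above did not change the activation.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))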

I tried to build my own ReLU LSTM below. Could you have a look at my solution? Is this implementation correct? Thank you.

import torch
import torch.nn as nn

class CustomLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(CustomLSTMCell, self).__init__()
        self.hidden_size = hidden_size

        # LSTM gates: input, forget, cell, and output
        self.w_i = nn.Linear(input_size + hidden_size, hidden_size)
        self.w_f = nn.Linear(input_size + hidden_size, hidden_size)
        self.w_c = nn.Linear(input_size + hidden_size, hidden_size)
        self.w_o = nn.Linear(input_size + hidden_size, hidden_size)

        # Activation functions
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, h_prev, c_prev):
        # x: (batch, input_size); h_prev, c_prev: (batch, hidden_size) for a single time step
        combined = torch.cat((x, h_prev), dim=1)

        i_t = self.sigmoid(self.w_i(combined))  # Input gate
        f_t = self.sigmoid(self.w_f(combined))  # Forget gate
        c_tilde = self.relu(self.w_c(combined)) # ReLU activation instead of tanh
        c_t = f_t * c_prev + i_t * c_tilde      # Cell state update
        o_t = self.sigmoid(self.w_o(combined))  # Output gate
        h_t = o_t * self.relu(c_t)              # New hidden state using ReLU

        return h_t, c_t


class CustomLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(CustomLSTM, self).__init__()
        self.num_layers = num_layers

        # Create multiple LSTM layers
        self.lstm_cells = nn.ModuleList([
            CustomLSTMCell(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])

    def forward(self, x, h_prev, c_prev):
        # x: (batch, input_size); h_prev, c_prev: (num_layers, batch, hidden_size)
        h_next_list, c_next_list = [], []

        for i in range(self.num_layers):
            h_next, c_next = self.lstm_cells[i](x, h_prev[i], c_prev[i])
            x = h_next  # Pass output to next layer
            h_next_list.append(h_next)
            c_next_list.append(c_next)

        return torch.stack(h_next_list), torch.stack(c_next_list)


# Example usage
input_size = 10
hidden_size = 20
num_layers = 3

lstm_model = CustomLSTM(input_size, hidden_size, num_layers)
x = torch.randn(1, input_size)
h_prev = torch.randn(num_layers, 1, hidden_size)  # Initialize all layers
c_prev = torch.randn(num_layers, 1, hidden_size)

h_next, c_next = lstm_model(x, h_prev, c_prev)

print("Next Hidden States:", h_next)
print("Next Cell States:", c_next)