Coupling Forget Gate and Input Gate of LSTM
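
In a CIFG (coupled input and forget gate) variant of the LSTM, the forget gate has no parameters of its own; it is tied to the input gate as f_t = 1 - i_t, so the cell only needs three gate projections (input, cell candidate, output) instead of the usual four.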

Here is my custom implementation of a CIFG in PyTorch:

import math
import torch
from torch import nn

class CIFGCell(nn.Module):
    def __init__(self, inp_dim, hidden_dim):
        super().__init__()

        self.inp_dim = inp_dim
        self.hidden_dim = hidden_dim

        # only three gate projections (input, cell candidate, output) are needed,
        # since the forget gate is derived from the input gate
        self.w_ih = nn.Parameter(torch.empty((3 * hidden_dim, inp_dim)))
        self.w_hh = nn.Parameter(torch.empty((3 * hidden_dim, hidden_dim)))
        self.b_ih = nn.Parameter(torch.empty(3 * hidden_dim))
        self.b_hh = nn.Parameter(torch.empty(3 * hidden_dim))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        # same uniform init as torch.nn.LSTM: U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim))
        stdv = 1.0 / math.sqrt(self.hidden_dim) if self.hidden_dim > 0 else 0
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x, hidden=None):
        # x: (batch, inp_dim); hidden: optional (hx, cx) tuple, each (batch, hidden_dim)
        if hidden is None:
            batch_size = x.size(0)
            hx = torch.zeros(batch_size, self.hidden_dim, dtype=x.dtype, device=x.device)
            cx = torch.zeros(batch_size, self.hidden_dim, dtype=x.dtype, device=x.device)
        else:
            hx, cx = hidden

        # compute all three gate pre-activations in one matmul per input
        gates = torch.mm(x, self.w_ih.t()) + torch.mm(hx, self.w_hh.t()) + self.b_ih + self.b_hh
        ingate, cellgate, outgate = gates.chunk(3, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = 1 - ingate  # coupling: f_t = 1 - i_t replaces a separately parameterised forget gate
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, cy


class CIFG(nn.Module):
    def __init__(self, inp_dim, hidden_dim, n_layers=1):
        super().__init__()

        self.inp_dim = inp_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        # first layer maps inp_dim -> hidden_dim, every further layer hidden_dim -> hidden_dim
        self.layers = nn.ModuleList([
            CIFGCell(inp_dim, hidden_dim) if i == 0 else CIFGCell(hidden_dim, hidden_dim)
            for i in range(n_layers)
        ])

    def forward(self, x, hidden=None):
        # x: (seq_len, batch, inp_dim), i.e. the nn.LSTM layout with batch_first=False
        if hidden is None:
            batch_size = x.size(1)
            hx = torch.zeros(self.n_layers, batch_size, self.hidden_dim, dtype=x.dtype, device=x.device)
            cx = torch.zeros(self.n_layers, batch_size, self.hidden_dim, dtype=x.dtype, device=x.device)
        else:
            hx, cx = hidden

        outputs = []
        for ts in range(x.size(0)):
            hx, cx = self.forward_timestep(x[ts], hx, cx)
            outputs.append(hx[-1])  # output of the top layer at this timestep

        # outputs: (seq_len, batch, hidden_dim); hx, cx: (n_layers, batch, hidden_dim)
        return torch.stack(outputs), (hx, cx)

    def forward_timestep(self, x, hx, cx):
        # run one timestep through the layer stack; hx, cx: (n_layers, batch, hidden_dim)
        hx = hx.chunk(self.n_layers, 0)
        cx = cx.chunk(self.n_layers, 0)
        prev_output = x
        hidden_states = []
        cell_states = []
        for i, layer in enumerate(self.layers):
            hy, cy = layer(prev_output, (hx[i].squeeze(0), cx[i].squeeze(0)))
            hidden_states.append(hy)
            cell_states.append(cy)
            prev_output = hy  # each layer feeds its output to the next layer

        # re-stack the per-layer states into (n_layers, batch, hidden_dim) tensors
        return torch.stack(hidden_states), torch.stack(cell_states)
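
For a quick shape check, here is a minimal smoke test of the classes above (the dimensions are arbitrary); the layout follows nn.LSTM with batch_first=False:

cifg = CIFG(inp_dim=8, hidden_dim=16, n_layers=2)
x = torch.randn(5, 3, 8)   # (seq_len, batch, inp_dim)
out, (hn, cn) = cifg(x)
print(out.shape)           # torch.Size([5, 3, 16])
print(hn.shape, cn.shape)  # torch.Size([2, 3, 16]) torch.Size([2, 3, 16])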