Coupling Forget Gate and Input Gate of LSTM

TensorFlow provides an LSTM cell that couples the forget and input gates but otherwise acts as a typical LSTM cell. TensorFlow also provides a GRU cell, which is distinct.

I have been looking through the PyTorch documentation and cannot find any mention of an LSTM cell of this type — only the standard GRU and LSTM cells, with no variations on them.

I know the GRU combines these gates, but I want the LSTM's benefit of capturing longer-distance relations, and I want to keep the cell state and hidden state separate.

Is there an equivalent in PyTorch to the cell type provided by TensorFlow, or is this something I will need to build myself?

I would also be interested in this

Any update on this?
Where can I find an implementation of a CIFG (coupled input and forget gate) LSTM in PyTorch?

Here is my custom implementation of a CIFG in PyTorch:

import math
import torch
from torch import nn

class CIFGCell(nn.Module):
    """LSTM cell with coupled input and forget gates (CIFG).

    Behaves like a standard LSTM cell, except that the forget gate is not
    learned separately: it is tied to the input gate as ``f = 1 - i``.
    Only three gate blocks are parameterised. They are stacked along dim 0
    of each weight and bias in the order (input, candidate, output).

    Args:
        inp_dim: Number of features in the input ``x``.
        hidden_dim: Number of features in the hidden and cell states.
    """

    def __init__(self, inp_dim, hidden_dim):
        # nn.Module.__init__ must run before any nn.Parameter is assigned;
        # otherwise the parameter assignment raises AttributeError.
        super().__init__()
        self.inp_dim = inp_dim
        self.hidden_dim = hidden_dim

        self.w_ih = nn.Parameter(torch.empty(3 * hidden_dim, inp_dim))
        self.w_hh = nn.Parameter(torch.empty(3 * hidden_dim, hidden_dim))
        self.b_ih = nn.Parameter(torch.empty(3 * hidden_dim))
        self.b_hh = nn.Parameter(torch.empty(3 * hidden_dim))

        # torch.empty leaves the storage uninitialised, so fill it now;
        # the original never called this and trained from garbage values.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Fill every parameter from U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim))."""
        stdv = 1.0 / math.sqrt(self.hidden_dim) if self.hidden_dim > 0 else 0
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x, hidden=None):
        """Advance the cell by one time step.

        Args:
            x: Input of shape ``(batch, inp_dim)``.
            hidden: Optional ``(h, c)`` pair, each of shape
                ``(batch, hidden_dim)``. When ``None``, both states start
                at zero with ``x``'s dtype and device.

        Returns:
            ``(h_next, c_next)``, each of shape ``(batch, hidden_dim)``.
        """
        if hidden is None:
            batch_size = x.size(0)
            hx = torch.zeros(batch_size, self.hidden_dim, dtype=x.dtype, device=x.device)
            cx = torch.zeros(batch_size, self.hidden_dim, dtype=x.dtype, device=x.device)
        else:
            hx, cx = hidden

        gates = (
            torch.mm(x, self.w_ih.t())
            + torch.mm(hx, self.w_hh.t())
            + self.b_ih
            + self.b_hh
        )
        ingate, cellgate, outgate = gates.chunk(3, 1)

        ingate = torch.sigmoid(ingate)
        # The coupling that defines CIFG: forget exactly as much as is written.
        forgetgate = 1 - ingate
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = forgetgate * cx + ingate * cellgate
        hy = outgate * torch.tanh(cy)
        return hy, cy

class CIFG(nn.Module):
    """Stack of CIFG cells unrolled over a time-major sequence.

    Layer 0 consumes the input features. Every later layer consumes the
    hidden state of the layer below it at the same time step.

    Args:
        inp_dim: Number of features per time step in the input.
        hidden_dim: Number of features in each layer's hidden/cell state.
        n_layers: Number of stacked CIFG cells.
    """

    def __init__(self, inp_dim, hidden_dim, n_layers=1):
        # Required before nn.ModuleList / Parameters can be registered.
        super().__init__()
        self.inp_dim = inp_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        self.layers = nn.ModuleList(
            [
                CIFGCell(inp_dim if i == 0 else hidden_dim, hidden_dim)
                for i in range(n_layers)
            ]
        )

    def forward(self, x, hidden=None):
        """Run the whole sequence through all layers.

        Args:
            x: Time-major input of shape ``(seq_len, batch, inp_dim)``.
            hidden: Optional ``(h0, c0)`` pair, each of shape
                ``(n_layers, batch, hidden_dim)``. Zeros when ``None``.

        Returns:
            ``(outputs, (h_n, c_n))``. ``outputs`` has shape
            ``(seq_len, batch, hidden_dim)`` and holds the top layer's hidden
            state at every step. ``h_n`` and ``c_n`` are the final states of
            all layers, each ``(n_layers, batch, hidden_dim)``.
        """
        if hidden is None:
            batch_size = x.size(1)
            hx = torch.zeros(self.n_layers, batch_size, self.hidden_dim, dtype=x.dtype, device=x.device)
            cx = torch.zeros(self.n_layers, batch_size, self.hidden_dim, dtype=x.dtype, device=x.device)
        else:
            hx, cx = hidden

        outputs = []
        for ts in range(x.size(0)):
            hx, cx = self.forward_timestep(x[ts], hx, cx)
            # The sequence output is the hidden state of the last layer.
            outputs.append(hx[-1])

        return torch.stack(outputs), (hx, cx)

    def forward_timestep(self, x, hx, cx):
        """Advance every layer by a single time step.

        Args:
            x: Input for this step, shape ``(batch, inp_dim)``.
            hx: Hidden states from the previous step, ``(n_layers, batch, hidden_dim)``.
            cx: Cell states from the previous step, ``(n_layers, batch, hidden_dim)``.

        Returns:
            ``(h, c)``: the new hidden and cell states of all layers,
            each ``(n_layers, batch, hidden_dim)``.
        """
        new_h = []
        new_c = []
        layer_input = x
        for i, layer in enumerate(self.layers):
            h, c = layer(layer_input, (hx[i], cx[i]))
            new_h.append(h)
            new_c.append(c)
            # This layer's hidden state is the next layer's input.
            layer_input = h

        return torch.stack(new_h), torch.stack(new_c)